Combine transcribe endpoints

This commit is contained in:
Ahmet Öner
2022-11-07 00:11:58 +01:00
parent 1dc7f476bc
commit 285e555b36

View File

@@ -1,6 +1,6 @@
import uvicorn
from fastapi import FastAPI, File, UploadFile, Query
from fastapi.responses import StreamingResponse
from fastapi import FastAPI, File, UploadFile, Query, Response
from fastapi.responses import StreamingResponse, RedirectResponse
import whisper
from whisper.utils import write_srt, write_vtt
import os
@@ -14,7 +14,18 @@ import torch
SAMPLE_RATE=16000
app = FastAPI()
app = FastAPI(
title="Webservice API",
description="OpenAI Whisper ASR Webservice API",
contact={
"url": "https://github.com/ahmetoner/whisper-asr-webservice/",
},
swagger_ui_parameters={"defaultModelsExpandDepth": -1},
license_info={
"name": "MIT License",
"url": "https://github.com/ahmetoner/whisper-asr-webservice/blob/main/LICENCE",
},
)
model_name= os.getenv("ASR_MODEL", "base")
@@ -25,19 +36,37 @@ else:
model_lock = Lock()
@app.post("/asr")
def transcribe_file(
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
async def index():
return "/docs"
@app.post("/asr", tags=["Endpoints"])
def transcribe(
audio_file: UploadFile = File(...),
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
output : Union[str, None] = Query(default="json", enum=["json", "vtt", "srt"]),
):
result = run_asr(audio_file.file, task, language)
return result
filename = audio_file.filename.split('.')[0]
if(output == "srt"):
srt_file = StringIO()
write_srt(result["segments"], file = srt_file)
srt_file.seek(0)
return StreamingResponse(srt_file, media_type="text/plain",
headers={'Content-Disposition': f'attachment; filename="{filename}.srt"'})
elif(output == "vtt"):
vtt_file = StringIO()
write_vtt(result["segments"], file = vtt_file)
vtt_file.seek(0)
return StreamingResponse(vtt_file, media_type="text/plain",
headers={'Content-Disposition': f'attachment; filename="{filename}.vtt"'})
else:
return result
@app.post("/detect-language")
@app.post("/detect-language", tags=["Endpoints"])
def language_detection(
audio_file: UploadFile = File(...),
):
@@ -60,40 +89,6 @@ def language_detection(
return result
@app.post("/get-srt", response_class=StreamingResponse)
def transcribe_file2srt(
audio_file: UploadFile = File(...),
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
):
result = run_asr(audio_file.file, task, language)
srt_file = StringIO()
write_srt(result["segments"], file = srt_file)
srt_file.seek(0)
srt_filename = f"{audio_file.filename.split('.')[0]}.srt"
return StreamingResponse(srt_file, media_type="text/plain",
headers={'Content-Disposition': f'attachment; filename="{srt_filename}"'})
@app.post("/get-vtt", response_class=StreamingResponse)
def transcribe_file2vtt(
audio_file: UploadFile = File(...),
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
):
result = run_asr(audio_file.file, task, language)
vtt_file = StringIO()
write_vtt(result["segments"], file = vtt_file)
vtt_file.seek(0)
vtt_filename = f"{audio_file.filename.split('.')[0]}.vtt"
return StreamingResponse(vtt_file, media_type="text/plain",
headers={'Content-Disposition': f'attachment; filename="{vtt_filename}"'})
def run_asr(file: BinaryIO, task: Union[str, None], language: Union[str, None] ):
audio = load_audio(file)
options_dict = {"task" : task}