mirror of
https://github.com/ahmetoner/whisper-asr-webservice.git
synced 2023-04-14 03:48:29 +03:00
Combine transcribe endpoints
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, File, UploadFile, Query
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi import FastAPI, File, UploadFile, Query, Response
|
||||
from fastapi.responses import StreamingResponse, RedirectResponse
|
||||
import whisper
|
||||
from whisper.utils import write_srt, write_vtt
|
||||
import os
|
||||
@@ -14,7 +14,18 @@ import torch
|
||||
|
||||
SAMPLE_RATE=16000
|
||||
|
||||
app = FastAPI()
|
||||
app = FastAPI(
|
||||
title="Webservice API",
|
||||
description="OpenAI Whisper ASR Webservice API",
|
||||
contact={
|
||||
"url": "https://github.com/ahmetoner/whisper-asr-webservice/",
|
||||
},
|
||||
swagger_ui_parameters={"defaultModelsExpandDepth": -1},
|
||||
license_info={
|
||||
"name": "MIT License",
|
||||
"url": "https://github.com/ahmetoner/whisper-asr-webservice/blob/main/LICENCE",
|
||||
},
|
||||
)
|
||||
|
||||
model_name= os.getenv("ASR_MODEL", "base")
|
||||
|
||||
@@ -25,19 +36,37 @@ else:
|
||||
|
||||
model_lock = Lock()
|
||||
|
||||
@app.post("/asr")
|
||||
def transcribe_file(
|
||||
@app.get("/", response_class=RedirectResponse, include_in_schema=False)
|
||||
async def index():
|
||||
return "/docs"
|
||||
|
||||
@app.post("/asr", tags=["Endpoints"])
|
||||
def transcribe(
|
||||
audio_file: UploadFile = File(...),
|
||||
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
|
||||
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
|
||||
output : Union[str, None] = Query(default="json", enum=["json", "vtt", "srt"]),
|
||||
):
|
||||
|
||||
result = run_asr(audio_file.file, task, language)
|
||||
|
||||
return result
|
||||
filename = audio_file.filename.split('.')[0]
|
||||
if(output == "srt"):
|
||||
srt_file = StringIO()
|
||||
write_srt(result["segments"], file = srt_file)
|
||||
srt_file.seek(0)
|
||||
return StreamingResponse(srt_file, media_type="text/plain",
|
||||
headers={'Content-Disposition': f'attachment; filename="{filename}.srt"'})
|
||||
elif(output == "vtt"):
|
||||
vtt_file = StringIO()
|
||||
write_vtt(result["segments"], file = vtt_file)
|
||||
vtt_file.seek(0)
|
||||
return StreamingResponse(vtt_file, media_type="text/plain",
|
||||
headers={'Content-Disposition': f'attachment; filename="{filename}.vtt"'})
|
||||
else:
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/detect-language")
|
||||
@app.post("/detect-language", tags=["Endpoints"])
|
||||
def language_detection(
|
||||
audio_file: UploadFile = File(...),
|
||||
):
|
||||
@@ -60,40 +89,6 @@ def language_detection(
|
||||
return result
|
||||
|
||||
|
||||
@app.post("/get-srt", response_class=StreamingResponse)
|
||||
def transcribe_file2srt(
|
||||
audio_file: UploadFile = File(...),
|
||||
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
|
||||
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
|
||||
):
|
||||
|
||||
result = run_asr(audio_file.file, task, language)
|
||||
|
||||
srt_file = StringIO()
|
||||
write_srt(result["segments"], file = srt_file)
|
||||
srt_file.seek(0)
|
||||
srt_filename = f"{audio_file.filename.split('.')[0]}.srt"
|
||||
return StreamingResponse(srt_file, media_type="text/plain",
|
||||
headers={'Content-Disposition': f'attachment; filename="{srt_filename}"'})
|
||||
|
||||
|
||||
@app.post("/get-vtt", response_class=StreamingResponse)
|
||||
def transcribe_file2vtt(
|
||||
audio_file: UploadFile = File(...),
|
||||
task : Union[str, None] = Query(default="transcribe", enum=["transcribe", "translate"]),
|
||||
language: Union[str, None] = Query(default=None, enum=LANGUAGE_CODES),
|
||||
):
|
||||
|
||||
result = run_asr(audio_file.file, task, language)
|
||||
|
||||
vtt_file = StringIO()
|
||||
write_vtt(result["segments"], file = vtt_file)
|
||||
vtt_file.seek(0)
|
||||
vtt_filename = f"{audio_file.filename.split('.')[0]}.vtt"
|
||||
return StreamingResponse(vtt_file, media_type="text/plain",
|
||||
headers={'Content-Disposition': f'attachment; filename="{vtt_filename}"'})
|
||||
|
||||
|
||||
def run_asr(file: BinaryIO, task: Union[str, None], language: Union[str, None] ):
|
||||
audio = load_audio(file)
|
||||
options_dict = {"task" : task}
|
||||
|
||||
Reference in New Issue
Block a user