mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
go : improve progress reporting and callback handling (#1024)
- Rename `cb` to `callNewSegment` in the `Process` function - Add `callProgress` as a new parameter to the `Process` function - Introduce `ProgressCallback` type for reporting progress during processing - Update `Whisper_full` function to include `progressCallback` parameter - Add `registerProgressCallback` function and `cbProgress` map for handling progress callbacks Signed-off-by: appleboy <appleboy.tw@gmail.com>
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
#include <stdlib.h>
|
||||
|
||||
extern void callNewSegment(void* user_data, int new);
|
||||
extern void callProgress(void* user_data, int progress);
|
||||
extern bool callEncoderBegin(void* user_data);
|
||||
|
||||
// Text segment callback
|
||||
@@ -26,6 +27,15 @@ static void whisper_new_segment_cb(struct whisper_context* ctx, struct whisper_s
|
||||
}
|
||||
}
|
||||
|
||||
// Progress callback
|
||||
// Called on every newly generated text segment
|
||||
// Use the whisper_full_...() functions to obtain the text segments
|
||||
static void whisper_progress_cb(struct whisper_context* ctx, struct whisper_state* state, int progress, void* user_data) {
|
||||
if(user_data != NULL && ctx != NULL) {
|
||||
callProgress(user_data, progress);
|
||||
}
|
||||
}
|
||||
|
||||
// Encoder begin callback
|
||||
// If not NULL, called before the encoder starts
|
||||
// If it returns false, the computation is aborted
|
||||
@@ -43,6 +53,8 @@ static struct whisper_full_params whisper_full_default_params_cb(struct whisper_
|
||||
params.new_segment_callback_user_data = (void*)(ctx);
|
||||
params.encoder_begin_callback = whisper_encoder_begin_cb;
|
||||
params.encoder_begin_callback_user_data = (void*)(ctx);
|
||||
params.progress_callback = whisper_progress_cb;
|
||||
params.progress_callback_user_data = (void*)(ctx);
|
||||
return params;
|
||||
}
|
||||
*/
|
||||
@@ -290,11 +302,19 @@ func (ctx *Context) Whisper_full_default_params(strategy SamplingStrategy) Param
|
||||
|
||||
// Run the entire model: PCM -> log mel spectrogram -> encoder -> decoder -> text
|
||||
// Uses the specified decoding strategy to obtain the text.
|
||||
func (ctx *Context) Whisper_full(params Params, samples []float32, encoderBeginCallback func() bool, newSegmentCallback func(int)) error {
|
||||
func (ctx *Context) Whisper_full(
|
||||
params Params,
|
||||
samples []float32,
|
||||
encoderBeginCallback func() bool,
|
||||
newSegmentCallback func(int),
|
||||
progressCallback func(int),
|
||||
) error {
|
||||
registerEncoderBeginCallback(ctx, encoderBeginCallback)
|
||||
registerNewSegmentCallback(ctx, newSegmentCallback)
|
||||
registerProgressCallback(ctx, progressCallback)
|
||||
defer registerEncoderBeginCallback(ctx, nil)
|
||||
defer registerNewSegmentCallback(ctx, nil)
|
||||
defer registerProgressCallback(ctx, nil)
|
||||
if C.whisper_full((*C.struct_whisper_context)(ctx), (C.struct_whisper_full_params)(params), (*C.float)(&samples[0]), C.int(len(samples))) == 0 {
|
||||
return nil
|
||||
} else {
|
||||
@@ -370,6 +390,7 @@ func (ctx *Context) Whisper_full_get_token_p(segment int, token int) float32 {
|
||||
|
||||
var (
|
||||
cbNewSegment = make(map[unsafe.Pointer]func(int))
|
||||
cbProgress = make(map[unsafe.Pointer]func(int))
|
||||
cbEncoderBegin = make(map[unsafe.Pointer]func() bool)
|
||||
)
|
||||
|
||||
@@ -381,6 +402,14 @@ func registerNewSegmentCallback(ctx *Context, fn func(int)) {
|
||||
}
|
||||
}
|
||||
|
||||
func registerProgressCallback(ctx *Context, fn func(int)) {
|
||||
if fn == nil {
|
||||
delete(cbProgress, unsafe.Pointer(ctx))
|
||||
} else {
|
||||
cbProgress[unsafe.Pointer(ctx)] = fn
|
||||
}
|
||||
}
|
||||
|
||||
func registerEncoderBeginCallback(ctx *Context, fn func() bool) {
|
||||
if fn == nil {
|
||||
delete(cbEncoderBegin, unsafe.Pointer(ctx))
|
||||
@@ -396,6 +425,13 @@ func callNewSegment(user_data unsafe.Pointer, new C.int) {
|
||||
}
|
||||
}
|
||||
|
||||
//export callProgress
|
||||
func callProgress(user_data unsafe.Pointer, progress C.int) {
|
||||
if fn, ok := cbProgress[user_data]; ok {
|
||||
fn(int(progress))
|
||||
}
|
||||
}
|
||||
|
||||
//export callEncoderBegin
|
||||
func callEncoderBegin(user_data unsafe.Pointer) C.bool {
|
||||
if fn, ok := cbEncoderBegin[user_data]; ok {
|
||||
|
||||
Reference in New Issue
Block a user