mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
Improve decoding (#291)
* whisper : prepare infra for new decoding strategies * whisper : apply logit filters and compute logprobs * whisper : add whisper_get_logits() * whisper : separate self and cross attention memory Initial step needed for supporting parallel decoders * whisper : move probs_id buffer to whisper_context * whisper : refactor kv cache into separate struct * whisper : move self-attention kv cache to whisper_decoder * whisper : wip decoding parameters + strategies * whisper : wip decoding parameters + strategies (part 2) * whisper : wip decoding parameters + strategies (part 3) * whisper : wip decoding parameters + strategies (part 4) * whisper : fix prompt_past update to not include prompt_init * whisper : temperature + best_of support * whisper : support for compression_ration_threshold We actually use entropy, but it is similar * command : fix example to use logits instead of obsolete probs * whisper : handle empty sequence ranking * whisper : add WHISPER_DEBUG + diagnostic prints + new main args * whisper : minor fixes * whisper : add beam-search support * whisper : bug fix when there no previous context * whisper : add comments * stream : disable temperature fallback For real-time processing, we always want a single decoder running at T=0 * whisper.swiftui : update example - fix paths + add empty folders
This commit is contained in:
60
whisper.h
60
whisper.h
@@ -74,6 +74,7 @@ extern "C" {
|
||||
whisper_token tid; // forced timestamp token id
|
||||
|
||||
float p; // probability of the token
|
||||
float plog; // log probability of the token
|
||||
float pt; // probability of the timestamp token
|
||||
float ptsum; // sum of probabilities of all timestamp tokens
|
||||
|
||||
@@ -136,6 +137,7 @@ extern "C" {
|
||||
// tokens + n_tokens is the provided context for the decoder.
|
||||
// n_past is the number of tokens to use from previous decoder calls.
|
||||
// Returns 0 on success
|
||||
// TODO: add support for multiple decoders
|
||||
WHISPER_API int whisper_decode(
|
||||
struct whisper_context * ctx,
|
||||
const whisper_token * tokens,
|
||||
@@ -143,14 +145,6 @@ extern "C" {
|
||||
int n_past,
|
||||
int n_threads);
|
||||
|
||||
// Token sampling methods.
|
||||
// These are provided for convenience and can be used after each call to whisper_decode().
|
||||
// You can also implement your own sampling method using the whisper_get_probs() function.
|
||||
// whisper_sample_best() returns the token with the highest probability
|
||||
// whisper_sample_timestamp() returns the most probable timestamp token
|
||||
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
|
||||
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
|
||||
|
||||
// Convert the provided text into tokens.
|
||||
// The tokens pointer must be large enough to hold the resulting tokens.
|
||||
// Returns the number of tokens on success, no more than n_max_tokens
|
||||
@@ -192,8 +186,11 @@ extern "C" {
|
||||
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
|
||||
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
|
||||
|
||||
// The probabilities for the next token
|
||||
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
|
||||
// Token logits obtained from the last call to whisper_decode()
|
||||
// The logits for the last token are stored in the last row
|
||||
// Rows: n_tokens
|
||||
// Cols: n_vocab
|
||||
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
|
||||
|
||||
// Token Id -> String. Uses the vocabulary in the provided context
|
||||
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
|
||||
@@ -222,8 +219,8 @@ extern "C" {
|
||||
|
||||
// Available sampling strategies
|
||||
enum whisper_sampling_strategy {
|
||||
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
|
||||
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
|
||||
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
|
||||
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
|
||||
};
|
||||
|
||||
// Text segment callback
|
||||
@@ -243,17 +240,17 @@ extern "C" {
|
||||
enum whisper_sampling_strategy strategy;
|
||||
|
||||
int n_threads;
|
||||
int n_max_text_ctx;
|
||||
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
|
||||
int offset_ms; // start offset in ms
|
||||
int duration_ms; // audio duration to process in ms
|
||||
|
||||
bool translate;
|
||||
bool no_context;
|
||||
bool no_context; // do not use initial prompt for the decoder (if any)
|
||||
bool single_segment; // force single segment output (useful for streaming)
|
||||
bool print_special;
|
||||
bool print_progress;
|
||||
bool print_realtime;
|
||||
bool print_timestamps;
|
||||
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
|
||||
bool print_progress; // print progress information
|
||||
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
|
||||
bool print_timestamps; // print timestamps for each text segment when printing realtime
|
||||
|
||||
// [EXPERIMENTAL] token-level timestamps
|
||||
bool token_timestamps; // enable token-level timestamps
|
||||
@@ -263,10 +260,11 @@ extern "C" {
|
||||
int max_tokens; // max tokens per segment (0 = no limit)
|
||||
|
||||
// [EXPERIMENTAL] speed-up techniques
|
||||
// note: these can significantly reduce the quality of the output
|
||||
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
|
||||
int audio_ctx; // overwrite the audio context size (0 = use default)
|
||||
|
||||
// tokens to provide the whisper model as initial prompt
|
||||
// tokens to provide to the whisper decoder as initial prompt
|
||||
// these are prepended to any existing text context from a previous call
|
||||
const whisper_token * prompt_tokens;
|
||||
int prompt_n_tokens;
|
||||
@@ -274,19 +272,35 @@ extern "C" {
|
||||
// for auto-detection, set to nullptr, "" or "auto"
|
||||
const char * language;
|
||||
|
||||
// common decoding parameters:
|
||||
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
|
||||
|
||||
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
|
||||
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
|
||||
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
|
||||
|
||||
// fallback parameters
|
||||
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
|
||||
float temperature_inc;
|
||||
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
|
||||
float logprob_thold;
|
||||
float no_speech_thold; // TODO: not implemented
|
||||
|
||||
struct {
|
||||
int n_past;
|
||||
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
|
||||
} greedy;
|
||||
|
||||
struct {
|
||||
int n_past;
|
||||
int beam_width;
|
||||
int n_best;
|
||||
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
|
||||
|
||||
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
|
||||
} beam_search;
|
||||
|
||||
// called for every newly generated text segment
|
||||
whisper_new_segment_callback new_segment_callback;
|
||||
void * new_segment_callback_user_data;
|
||||
|
||||
// called each time before the encoder starts
|
||||
whisper_encoder_begin_callback encoder_begin_callback;
|
||||
void * encoder_begin_callback_user_data;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user