Improve decoding (#291)

* whisper : prepare infra for new decoding strategies

* whisper : apply logit filters and compute logprobs

* whisper : add whisper_get_logits()

* whisper : separate self and cross attention memory

Initial step needed for supporting parallel decoders

* whisper : move probs_id buffer to whisper_context

* whisper : refactor kv cache into separate struct

* whisper : move self-attention kv cache to whisper_decoder

* whisper : wip decoding parameters + strategies

* whisper : wip decoding parameters + strategies (part 2)

* whisper : wip decoding parameters + strategies (part 3)

* whisper : wip decoding parameters + strategies (part 4)

* whisper : fix prompt_past update to not include prompt_init

* whisper : temperature + best_of support

* whisper : support for compression_ration_threshold

We actually use entropy, but it is similar

* command : fix example to use logits instead of obsolete probs

* whisper : handle empty sequence ranking

* whisper : add WHISPER_DEBUG + diagnostic prints + new main args

* whisper : minor fixes

* whisper : add beam-search support

* whisper : bug fix when there no previous context

* whisper : add comments

* stream : disable temperature fallback

For real-time processing, we always want a single decoder running at T=0

* whisper.swiftui : update example - fix paths + add empty folders
This commit is contained in:
Georgi Gerganov
2023-01-15 11:29:57 +02:00
committed by GitHub
parent a6dbd9188b
commit 8de452c18b
11 changed files with 1531 additions and 784 deletions

View File

@@ -74,6 +74,7 @@ extern "C" {
whisper_token tid; // forced timestamp token id
float p; // probability of the token
float plog; // log probability of the token
float pt; // probability of the timestamp token
float ptsum; // sum of probabilities of all timestamp tokens
@@ -136,6 +137,7 @@ extern "C" {
// tokens + n_tokens is the provided context for the decoder.
// n_past is the number of tokens to use from previous decoder calls.
// Returns 0 on success
// TODO: add support for multiple decoders
WHISPER_API int whisper_decode(
struct whisper_context * ctx,
const whisper_token * tokens,
@@ -143,14 +145,6 @@ extern "C" {
int n_past,
int n_threads);
// Token sampling methods.
// These are provided for convenience and can be used after each call to whisper_decode().
// You can also implement your own sampling method using the whisper_get_probs() function.
// whisper_sample_best() returns the token with the highest probability
// whisper_sample_timestamp() returns the most probable timestamp token
WHISPER_API whisper_token_data whisper_sample_best(struct whisper_context * ctx);
WHISPER_API whisper_token_data whisper_sample_timestamp(struct whisper_context * ctx, bool is_initial);
// Convert the provided text into tokens.
// The tokens pointer must be large enough to hold the resulting tokens.
// Returns the number of tokens on success, no more than n_max_tokens
@@ -192,8 +186,11 @@ extern "C" {
WHISPER_API int whisper_n_audio_ctx (struct whisper_context * ctx);
WHISPER_API int whisper_is_multilingual(struct whisper_context * ctx);
// The probabilities for the next token
WHISPER_API float * whisper_get_probs(struct whisper_context * ctx);
// Token logits obtained from the last call to whisper_decode()
// The logits for the last token are stored in the last row
// Rows: n_tokens
// Cols: n_vocab
WHISPER_API float * whisper_get_logits(struct whisper_context * ctx);
// Token Id -> String. Uses the vocabulary in the provided context
WHISPER_API const char * whisper_token_to_str(struct whisper_context * ctx, whisper_token token);
@@ -222,8 +219,8 @@ extern "C" {
// Available sampling strategies
enum whisper_sampling_strategy {
WHISPER_SAMPLING_GREEDY, // Always select the most probable token
WHISPER_SAMPLING_BEAM_SEARCH, // TODO: not implemented yet!
WHISPER_SAMPLING_GREEDY, // similar to OpenAI's GreefyDecoder
WHISPER_SAMPLING_BEAM_SEARCH, // similar to OpenAI's BeamSearchDecoder
};
// Text segment callback
@@ -243,17 +240,17 @@ extern "C" {
enum whisper_sampling_strategy strategy;
int n_threads;
int n_max_text_ctx;
int n_max_text_ctx; // max tokens to use from past text as prompt for the decoder
int offset_ms; // start offset in ms
int duration_ms; // audio duration to process in ms
bool translate;
bool no_context;
bool no_context; // do not use initial prompt for the decoder (if any)
bool single_segment; // force single segment output (useful for streaming)
bool print_special;
bool print_progress;
bool print_realtime;
bool print_timestamps;
bool print_special; // print special tokens (e.g. <SOT>, <EOT>, <BEG>, etc.)
bool print_progress; // print progress information
bool print_realtime; // print results from within whisper.cpp (avoid it, use callback instead)
bool print_timestamps; // print timestamps for each text segment when printing realtime
// [EXPERIMENTAL] token-level timestamps
bool token_timestamps; // enable token-level timestamps
@@ -263,10 +260,11 @@ extern "C" {
int max_tokens; // max tokens per segment (0 = no limit)
// [EXPERIMENTAL] speed-up techniques
// note: these can significantly reduce the quality of the output
bool speed_up; // speed-up the audio by 2x using Phase Vocoder
int audio_ctx; // overwrite the audio context size (0 = use default)
// tokens to provide the whisper model as initial prompt
// tokens to provide to the whisper decoder as initial prompt
// these are prepended to any existing text context from a previous call
const whisper_token * prompt_tokens;
int prompt_n_tokens;
@@ -274,19 +272,35 @@ extern "C" {
// for auto-detection, set to nullptr, "" or "auto"
const char * language;
// common decoding parameters:
bool suppress_blank; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L89
float temperature; // initial decoding temperature, ref: https://ai.stackexchange.com/a/32478
float max_initial_ts; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/decoding.py#L97
float length_penalty; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L267
// fallback parameters
// ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L274-L278
float temperature_inc;
float entropy_thold; // similar to OpenAI's "compression_ratio_threshold"
float logprob_thold;
float no_speech_thold; // TODO: not implemented
struct {
int n_past;
int best_of; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L264
} greedy;
struct {
int n_past;
int beam_width;
int n_best;
int beam_size; // ref: https://github.com/openai/whisper/blob/f82bc59f5ea234d4b97fb2860842ed38519f7e65/whisper/transcribe.py#L265
float patience; // TODO: not implemented, ref: https://arxiv.org/pdf/2204.05424.pdf
} beam_search;
// called for every newly generated text segment
whisper_new_segment_callback new_segment_callback;
void * new_segment_callback_user_data;
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;
};