Prompt previous tokens for streaming (#163)

* feat: prompt previous tokens for streaming

I used a vector pointer instead of vector itself because it gave weird errors, and why not

* convert vector to use with C api

* feat: remove old refs, check for prompt size

* feat: use better way of getting the pointer
This commit is contained in:
M. Eren Akbiyik
2022-11-22 17:10:35 +01:00
committed by GitHub
parent 78116f8eda
commit 63ae03b8e0
3 changed files with 33 additions and 0 deletions

View File

@@ -2412,6 +2412,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@@ -2455,6 +2458,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.speed_up =*/ false,
/*.audio_ctx =*/ 0,
/*.prompt_tokens =*/ nullptr,
/*.prompt_n_tokens =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@@ -2584,6 +2590,15 @@ int whisper_full(
prompt_past.clear();
}
// Prepend the prompt tokens to the prompt_past
if (params.prompt_tokens && params.prompt_n_tokens > 0) {
// Parse tokens from the pointer (it points to an std::vector)
for (int i = 0; i < params.prompt_n_tokens; i++) {
prompt_past.push_back(params.prompt_tokens[i]);
}
std::rotate(prompt_past.begin(), prompt_past.end() - params.prompt_n_tokens, prompt_past.end());
}
// overwrite audio_ctx
ctx->exp_n_audio_ctx = params.audio_ctx;