whisper : token-level timestamp refactoring (#49, #120)

This turned out pretty good overall. The algorithm has been moved from
main.cpp to whisper.cpp and can be reused for all subtitles types. This
means that now you can specify the maximum length of the generated
lines. Simply provide the "-ml" argument specifying the max length in
number of characters
This commit is contained in:
Georgi Gerganov
2022-11-02 21:18:20 +02:00
parent 4b1c32e8ea
commit d5afebd37c
5 changed files with 518 additions and 370 deletions

View File

@@ -418,6 +418,12 @@ struct whisper_context {
std::vector<whisper_segment> result_all;
std::vector<whisper_token> prompt_past;
// [EXPERIMENTAL] token-level timestamps data
int64_t t_beg;
int64_t t_last;
whisper_token tid_last;
std::vector<float> energy; // PCM signal energy
};
// load the model from a ggml file
@@ -431,7 +437,7 @@ struct whisper_context {
//
// see the convert-pt-to-ggml.py script for details
//
bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
static bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
fprintf(stderr, "%s: loading model from '%s'\n", __func__, fname.c_str());
auto & model = wctx.model;
@@ -1062,7 +1068,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
// - n_threads: number of threads to use
// - mel_offset: offset in the mel spectrogram (i.e. audio offset)
//
bool whisper_encode(
static bool whisper_encode(
whisper_context & wctx,
const int n_threads,
const int mel_offset) {
@@ -1448,7 +1454,7 @@ bool whisper_encode(
// - n_tokens: number of tokens in the prompt
// - n_past: number of past tokens to prefix the prompt with
//
bool whisper_decode(
static bool whisper_decode(
whisper_context & wctx,
const int n_threads,
const whisper_token * tokens,
@@ -1811,10 +1817,12 @@ bool whisper_decode(
}
// the most basic sampling scheme - select the top token
whisper_token_data whisper_sample_best(
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
const float * probs) {
whisper_token_data result;
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};
int n_logits = vocab.id_to_token.size();
@@ -1887,7 +1895,7 @@ whisper_token_data whisper_sample_best(
}
// samples only from the timestamps tokens
whisper_vocab::id whisper_sample_timestamp(
static whisper_vocab::id whisper_sample_timestamp(
const whisper_vocab & vocab,
const float * probs) {
int n_logits = vocab.id_to_token.size();
@@ -1939,7 +1947,7 @@ static std::string to_timestamp(int64_t t, bool comma = false) {
// naive Discrete Fourier Transform
// input is real-valued
// output is complex-valued
void dft(const std::vector<float> & in, std::vector<float> & out) {
static void dft(const std::vector<float> & in, std::vector<float> & out) {
int N = in.size();
out.resize(N*2);
@@ -1963,7 +1971,7 @@ void dft(const std::vector<float> & in, std::vector<float> & out) {
// poor man's implementation - use something better
// input is real-valued
// output is complex-valued
void fft(const std::vector<float> & in, std::vector<float> & out) {
static void fft(const std::vector<float> & in, std::vector<float> & out) {
out.resize(in.size()*2);
int N = in.size();
@@ -2014,7 +2022,7 @@ void fft(const std::vector<float> & in, std::vector<float> & out) {
}
// ref: https://github.com/openai/whisper/blob/main/whisper/audio.py#L92-L124
bool log_mel_spectrogram(
static bool log_mel_spectrogram(
const float * samples,
const int n_samples,
const int sample_rate,
@@ -2339,6 +2347,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@@ -2371,6 +2384,11 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
/*.print_realtime =*/ false,
/*.print_timestamps =*/ true,
/*.token_timestamps =*/ false,
/*.thold_pt =*/ 0.01f,
/*.thold_ptsum =*/ 0.01f,
/*.max_len =*/ 0,
/*.language =*/ "en",
/*.greedy =*/ {
@@ -2392,6 +2410,68 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str
return result;
}
// forward declarations
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window);
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum);
// wrap the last segment to max_len characters
// returns the number of new segments
static int whisper_wrap_segment(struct whisper_context * ctx, int max_len) {
auto segment = ctx->result_all.back();
int res = 1;
int acc = 0;
std::string text;
for (int i = 0; i < (int) segment.tokens.size(); i++) {
const auto & token = segment.tokens[i];
if (token.id >= whisper_token_eot(ctx)) {
continue;
}
const auto txt = whisper_token_to_str(ctx, token.id);
const int cur = strlen(txt);
if (acc + cur > max_len && i > 0) {
// split here
ctx->result_all.back().text = std::move(text);
ctx->result_all.back().t1 = token.t0;
ctx->result_all.back().tokens.resize(i);
ctx->result_all.push_back({});
ctx->result_all.back().t0 = token.t0;
ctx->result_all.back().t1 = segment.t1;
// add tokens [i, end] to the new segment
ctx->result_all.back().tokens.insert(
ctx->result_all.back().tokens.end(),
segment.tokens.begin() + i,
segment.tokens.end());
acc = 0;
text = "";
segment = ctx->result_all.back();
i = -1;
res++;
} else {
acc += cur;
text += txt;
}
}
ctx->result_all.back().text = std::move(text);
return res;
}
int whisper_full(
struct whisper_context * ctx,
struct whisper_full_params params,
@@ -2408,6 +2488,13 @@ int whisper_full(
return -1;
}
if (params.token_timestamps) {
ctx->t_beg = 0;
ctx->t_last = 0;
ctx->tid_last = 0;
ctx->energy = get_signal_energy(samples, n_samples, 32);
}
const int seek_start = params.offset_ms/10;
// if length of spectrogram is less than 1s (100 samples), then return
@@ -2557,6 +2644,7 @@ int whisper_full(
}
}
// shrink down to result_len
tokens_cur.resize(result_len);
for (const auto & r : tokens_cur) {
@@ -2595,8 +2683,19 @@ int whisper_full(
for (int j = i0; j <= i; j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
text = "";
@@ -2625,8 +2724,19 @@ int whisper_full(
for (int j = i0; j < (int) tokens_cur.size(); j++) {
result_all.back().tokens.push_back(tokens_cur[j]);
}
int n_new = 1;
if (params.token_timestamps) {
whisper_exp_compute_token_level_timestamps(
ctx, result_all.size() - 1, params.thold_pt, params.thold_ptsum);
if (params.max_len > 0) {
n_new = whisper_wrap_segment(ctx, params.max_len);
}
}
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, n_new, params.new_segment_callback_user_data);
}
}
}
@@ -2760,7 +2870,7 @@ int whisper_full_parallel(
// call the new_segment_callback for each segment
if (params.new_segment_callback) {
params.new_segment_callback(ctx, params.new_segment_callback_user_data);
params.new_segment_callback(ctx, 1, params.new_segment_callback_user_data);
}
}
@@ -2836,3 +2946,304 @@ const char * whisper_print_system_info() {
return s.c_str();
}
// =================================================================================================
//
// Experimental stuff below
//
// Not sure if these should be part of the library at all, because the quality of the results is not
// guaranteed. Might get removed at some point unless a robust algorithm implementation is found
//
// =================================================================================================
//
// token-level timestamps
//
static int timestamp_to_sample(int64_t t, int n_samples) {
return std::max(0, std::min((int) n_samples - 1, (int) ((t*WHISPER_SAMPLE_RATE)/100)));
}
static int64_t sample_to_timestamp(int i_sample) {
return (100*i_sample)/WHISPER_SAMPLE_RATE;
}
// a cost-function / heuristic that is high for text that takes longer to pronounce
// obviously, can be improved
static float voice_length(const std::string & text) {
float res = 0.0f;
for (size_t i = 0; i < text.size(); ++i) {
if (text[i] == ' ') {
res += 0.01f;
} else if (text[i] == ',') {
res += 2.00f;
} else if (text[i] == '.') {
res += 3.00f;
} else if (text[i] == '!') {
res += 3.00f;
} else if (text[i] == '?') {
res += 3.00f;
} else if (text[i] >= '0' && text[i] <= '9') {
res += 3.00f;
} else {
res += 1.00f;
}
}
return res;
}
// average the fabs of the signal
static std::vector<float> get_signal_energy(const float * signal, int n_samples, int n_samples_per_half_window) {
const int hw = n_samples_per_half_window;
std::vector<float> result(n_samples);
for (int i = 0; i < n_samples; i++) {
float sum = 0;
for (int j = -hw; j <= hw; j++) {
if (i + j >= 0 && i + j < n_samples) {
sum += fabs(signal[i + j]);
}
}
result[i] = sum/(2*hw + 1);
}
return result;
}
static void whisper_exp_compute_token_level_timestamps(
struct whisper_context * ctx,
int i_segment,
float thold_pt,
float thold_ptsum) {
auto & segment = ctx->result_all[i_segment];
auto & tokens = segment.tokens;
const int n_samples = ctx->energy.size();
if (n_samples == 0) {
fprintf(stderr, "%s: no signal data available\n", __func__);
return;
}
const int64_t t0 = segment.t0;
const int64_t t1 = segment.t1;
const int s0 = timestamp_to_sample(t0, n_samples);
const int s1 = timestamp_to_sample(t1, n_samples);
const int n = tokens.size();
if (n == 0) {
return;
}
if (n == 1) {
tokens[0].t0 = t0;
tokens[0].t1 = t1;
return;
}
auto & t_beg = ctx->t_beg;
auto & t_last = ctx->t_last;
auto & tid_last = ctx->tid_last;
for (int j = 0; j < n; ++j) {
auto & token = tokens[j];
if (j == 0) {
if (token.id == whisper_token_beg(ctx)) {
tokens[j ].t0 = t0;
tokens[j ].t1 = t0;
tokens[j + 1].t0 = t0;
t_beg = t0;
t_last = t0;
tid_last = whisper_token_beg(ctx);
} else {
tokens[j ].t0 = t_last;
}
}
const int64_t tt = t_beg + 2*(token.tid - whisper_token_beg(ctx));
tokens[j].id = token.id;
tokens[j].tid = token.tid;
tokens[j].p = token.p;
tokens[j].pt = token.pt;
tokens[j].ptsum = token.ptsum;
tokens[j].vlen = voice_length(whisper_token_to_str(ctx, token.id));
if (token.pt > thold_pt && token.ptsum > thold_ptsum && token.tid > tid_last && tt <= t1) {
if (j > 0) {
tokens[j - 1].t1 = tt;
}
tokens[j].t0 = tt;
tid_last = token.tid;
}
}
tokens[n - 2].t1 = t1;
tokens[n - 1].t0 = t1;
tokens[n - 1].t1 = t1;
t_last = t1;
// find intervals of tokens with unknown timestamps
// fill the timestamps by proportionally splitting the interval based on the token voice lengths
{
int p0 = 0;
int p1 = 0;
while (true) {
while (p1 < n && tokens[p1].t1 < 0) {
p1++;
}
if (p1 >= n) {
p1--;
}
if (p1 > p0) {
double psum = 0.0;
for (int j = p0; j <= p1; j++) {
psum += tokens[j].vlen;
}
//printf("analyzing %d - %d, psum = %f\n", p0, p1, psum);
const double dt = tokens[p1].t1 - tokens[p0].t0;
// split the time proportionally to the voice length
for (int j = p0 + 1; j <= p1; j++) {
const double ct = tokens[j - 1].t0 + dt*tokens[j - 1].vlen/psum;
tokens[j - 1].t1 = ct;
tokens[j ].t0 = ct;
}
}
p1++;
p0 = p1;
if (p1 >= n) {
break;
}
}
}
// fix up (just in case)
for (int j = 0; j < n - 1; j++) {
if (tokens[j].t1 < 0) {
tokens[j + 1].t0 = tokens[j].t1;
}
if (j > 0) {
if (tokens[j - 1].t1 > tokens[j].t0) {
tokens[j].t0 = tokens[j - 1].t1;
tokens[j].t1 = std::max(tokens[j].t0, tokens[j].t1);
}
}
}
// VAD
// expand or contract tokens based on voice activity
{
const int hw = WHISPER_SAMPLE_RATE/8;
for (int j = 0; j < n; j++) {
if (tokens[j].id >= whisper_token_eot(ctx)) {
continue;
}
int s0 = timestamp_to_sample(tokens[j].t0, n_samples);
int s1 = timestamp_to_sample(tokens[j].t1, n_samples);
const int ss0 = std::max(s0 - hw, 0);
const int ss1 = std::min(s1 + hw, n_samples);
const int ns = ss1 - ss0;
float sum = 0.0f;
for (int k = ss0; k < ss1; k++) {
sum += ctx->energy[k];
}
const float thold = 0.5*sum/ns;
{
int k = s0;
if (ctx->energy[k] > thold && j > 0) {
while (k > 0 && ctx->energy[k] > thold) {
k--;
}
tokens[j].t0 = sample_to_timestamp(k);
if (tokens[j].t0 < tokens[j - 1].t1) {
tokens[j].t0 = tokens[j - 1].t1;
} else {
s0 = k;
}
} else {
while (ctx->energy[k] < thold && k < s1) {
k++;
}
s0 = k;
tokens[j].t0 = sample_to_timestamp(k);
}
}
{
int k = s1;
if (ctx->energy[k] > thold) {
while (k < n_samples - 1 && ctx->energy[k] > thold) {
k++;
}
tokens[j].t1 = sample_to_timestamp(k);
if (j < ns - 1 && tokens[j].t1 > tokens[j + 1].t0) {
tokens[j].t1 = tokens[j + 1].t0;
} else {
s1 = k;
}
} else {
while (ctx->energy[k] < thold && k > s0) {
k--;
}
s1 = k;
tokens[j].t1 = sample_to_timestamp(k);
}
}
}
}
// fixed token expand (optional)
//{
// const int t_expand = 0;
// for (int j = 0; j < n; j++) {
// if (j > 0) {
// tokens[j].t0 = std::max(0, (int) (tokens[j].t0 - t_expand));
// }
// if (j < n - 1) {
// tokens[j].t1 = tokens[j].t1 + t_expand;
// }
// }
//}
// debug info
//for (int j = 0; j < n; ++j) {
// const auto & token = tokens[j];
// const auto tt = token.pt > thold_pt && token.ptsum > 0.01 ? whisper_token_to_str(ctx, token.tid) : "[?]";
// printf("%s: %10s %6.3f %6.3f %6.3f %6.3f %5d %5d '%s'\n", __func__,
// tt, token.p, token.pt, token.ptsum, token.vlen, (int) token.t0, (int) token.t1, whisper_token_to_str(ctx, token.id));
// if (tokens[j].id >= whisper_token_eot(ctx)) {
// continue;
// }
//}
}