mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
wip : experimental color coding of tokens based on probabilities
This commit is contained in:
45
main.cpp
45
main.cpp
@@ -5,12 +5,20 @@
|
||||
#define DR_WAV_IMPLEMENTATION
|
||||
#include "dr_wav.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <fstream>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
// Terminal color map. 10 colors grouped in ranges [0.0, 0.1, ..., 0.9]
|
||||
// Lowest is red, middle is yellow, highest is green.
|
||||
const std::vector<std::string> k_colors = {
|
||||
"\033[38;5;196m", "\033[38;5;202m", "\033[38;5;208m", "\033[38;5;214m", "\033[38;5;220m",
|
||||
"\033[38;5;226m", "\033[38;5;190m", "\033[38;5;154m", "\033[38;5;118m", "\033[38;5;82m",
|
||||
};
|
||||
|
||||
// 500 -> 00:05.000
|
||||
// 6000 -> 01:00.000
|
||||
std::string to_timestamp(int64_t t) {
|
||||
@@ -41,6 +49,7 @@ struct whisper_params {
|
||||
bool output_vtt = false;
|
||||
bool output_srt = false;
|
||||
bool print_special_tokens = false;
|
||||
bool print_colors = false;
|
||||
bool no_timestamps = false;
|
||||
|
||||
std::string language = "en";
|
||||
@@ -87,6 +96,8 @@ bool whisper_params_parse(int argc, char ** argv, whisper_params & params) {
|
||||
params.output_srt = true;
|
||||
} else if (arg == "-ps" || arg == "--print_special") {
|
||||
params.print_special_tokens = true;
|
||||
} else if (arg == "-pc" || arg == "--print_colors") {
|
||||
params.print_colors = true;
|
||||
} else if (arg == "-nt" || arg == "--no_timestamps") {
|
||||
params.no_timestamps = true;
|
||||
} else if (arg == "-m" || arg == "--model") {
|
||||
@@ -122,6 +133,7 @@ void whisper_print_usage(int argc, char ** argv, const whisper_params & params)
|
||||
fprintf(stderr, " -ovtt, --output-vtt output result in a vtt file\n");
|
||||
fprintf(stderr, " -osrt, --output-srt output result in a srt file\n");
|
||||
fprintf(stderr, " -ps, --print_special print special tokens\n");
|
||||
fprintf(stderr, " -pc, --print_colors print colors\n");
|
||||
fprintf(stderr, " -nt, --no_timestamps do not print timestamps\n");
|
||||
fprintf(stderr, " -l LANG, --language LANG spoken language (default: %s)\n", params.language.c_str());
|
||||
fprintf(stderr, " -m FNAME, --model FNAME model path (default: %s)\n", params.model.c_str());
|
||||
@@ -222,7 +234,7 @@ int main(int argc, char ** argv) {
|
||||
{
|
||||
whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);
|
||||
|
||||
wparams.print_realtime = true;
|
||||
wparams.print_realtime = !params.print_colors;
|
||||
wparams.print_progress = false;
|
||||
wparams.print_timestamps = !params.no_timestamps;
|
||||
wparams.print_special_tokens = params.print_special_tokens;
|
||||
@@ -242,16 +254,34 @@ int main(int argc, char ** argv) {
|
||||
|
||||
const int n_segments = whisper_full_n_segments(ctx);
|
||||
for (int i = 0; i < n_segments; ++i) {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
if (params.no_timestamps) {
|
||||
printf("%s", text);
|
||||
fflush(stdout);
|
||||
if (params.print_colors) {
|
||||
// TODO
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
printf("%s", text);
|
||||
fflush(stdout);
|
||||
}
|
||||
} else {
|
||||
const int64_t t0 = whisper_full_get_segment_t0(ctx, i);
|
||||
const int64_t t1 = whisper_full_get_segment_t1(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
if (params.print_colors) {
|
||||
printf("[%s --> %s] ", to_timestamp(t0).c_str(), to_timestamp(t1).c_str());
|
||||
for (int j = 0; j < whisper_full_n_tokens(ctx, i); ++j) {
|
||||
const char * text = whisper_full_get_token_text(ctx, i, j);
|
||||
const float p = whisper_full_get_token_p (ctx, i, j);
|
||||
|
||||
const int col = std::max(0, std::min((int) k_colors.size(), (int) (std::pow(p, 3)*float(k_colors.size()))));
|
||||
|
||||
printf("%s%s%s", k_colors[col].c_str(), text, "\033[0m");
|
||||
}
|
||||
printf("\n");
|
||||
} else {
|
||||
const char * text = whisper_full_get_segment_text(ctx, i);
|
||||
|
||||
printf("[%s --> %s] %s\n", to_timestamp(t0).c_str(), to_timestamp(t1).c_str(), text);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -260,7 +290,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// output to text file
|
||||
if (params.output_txt) {
|
||||
|
||||
const auto fname_txt = fname_inp + ".txt";
|
||||
std::ofstream fout_txt(fname_txt);
|
||||
if (!fout_txt.is_open()) {
|
||||
@@ -279,7 +308,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// output to VTT file
|
||||
if (params.output_vtt) {
|
||||
|
||||
const auto fname_vtt = fname_inp + ".vtt";
|
||||
std::ofstream fout_vtt(fname_vtt);
|
||||
if (!fout_vtt.is_open()) {
|
||||
@@ -304,7 +332,6 @@ int main(int argc, char ** argv) {
|
||||
|
||||
// output to SRT file
|
||||
if (params.output_srt) {
|
||||
|
||||
const auto fname_srt = fname_inp + ".srt";
|
||||
std::ofstream fout_srt(fname_srt);
|
||||
if (!fout_srt.is_open()) {
|
||||
|
||||
Reference in New Issue
Block a user