mirror of
https://github.com/ggerganov/whisper.cpp.git
synced 2023-11-04 02:52:44 +03:00
Adding sanitizer tests
This commit is contained in:
11
whisper.cpp
11
whisper.cpp
@@ -950,6 +950,7 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
||||
|
||||
// load weights
|
||||
{
|
||||
int n_loaded = 0;
|
||||
size_t total_size = 0;
|
||||
|
||||
while (true) {
|
||||
@@ -1004,9 +1005,17 @@ bool whisper_model_load(const std::string & fname, whisper_context & wctx) {
|
||||
|
||||
//printf("%24s - [%5d, %5d], type = %6s, %6.2f MB\n", name.data(), ne[0], ne[1], ftype == 0 ? "float" : "f16", ggml_nbytes(tensor)/1024.0/1024.0);
|
||||
total_size += ggml_nbytes(tensor);
|
||||
n_loaded++;
|
||||
}
|
||||
|
||||
printf("%s: model size = %8.2f MB\n", __func__, total_size/1024.0/1024.0);
|
||||
|
||||
if (n_loaded == 0) {
|
||||
printf("%s: WARN no tensors loaded from model file - assuming empty model for testing\n", __func__);
|
||||
} else if (n_loaded != model.tensors.size()) {
|
||||
fprintf(stderr, "%s: ERROR not all tensors loaded from model file - expected %zu, got %d\n", __func__, model.tensors.size(), n_loaded);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
fin.close();
|
||||
@@ -1772,8 +1781,6 @@ bool whisper_decode(
|
||||
}
|
||||
|
||||
// the most basic sampling scheme - select the top token
|
||||
// TODO: beam search
|
||||
// TODO: temperature
|
||||
whisper_vocab::id whisper_sample_best(
|
||||
const whisper_vocab & vocab,
|
||||
const float * probs, bool need_timestamp) {
|
||||
|
||||
Reference in New Issue
Block a user