diff --git a/models/convert-h5-to-ggml.py b/models/convert-h5-to-ggml.py index 80244d735e9..9f004d9bce5 100644 --- a/models/convert-h5-to-ggml.py +++ b/models/convert-h5-to-ggml.py @@ -107,6 +107,8 @@ def bytes_to_unicode(): fname_out = dir_out / "ggml-model.bin" tokens = json.load(open(dir_tokenizer / "vocab.json", "r", encoding="utf8")) +if "<|endoftext|>" in tokens: + del tokens["<|endoftext|>"] # use 16-bit or 32-bit floats use_f16 = True diff --git a/src/whisper.cpp b/src/whisper.cpp index f6793cb237b..d88e0ad4d9c 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -453,7 +453,7 @@ struct whisper_vocab { } int num_languages() const { - return n_vocab - 51765 - (is_multilingual() ? 1 : 0); + return token_translate - token_sot - 1; } }; @@ -1621,22 +1621,19 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con //printf("%s: vocab[%d] = '%s'\n", __func__, i, word.c_str()); } - vocab.n_vocab = model.hparams.n_vocab; - if (vocab.is_multilingual()) { - vocab.token_eot++; - vocab.token_sot++; + vocab.n_vocab = model.hparams.n_vocab; // all tokens, including special tokens - // account for variable number of language tokens - const int dt = vocab.num_languages() - 98; - - vocab.token_translate += dt; - vocab.token_transcribe += dt; - vocab.token_solm += dt; - vocab.token_prev += dt; - vocab.token_nosp += dt; - vocab.token_not += dt; - vocab.token_beg += dt; - } + vocab.token_eot = n_vocab; // <|endoftext|> 50256 for en, 50257 for multilingual, others for custom model + vocab.token_sot = n_vocab + 1; // <|startoftranscript|> + // [n_vocab + 2, vocab.n_vocab - 1507) are language tokens + // num_languages() = vocab.token_translate - vocab.token_sot - 1 = vocab.n_vocab - n_vocab - 1509 + vocab.token_translate = vocab.n_vocab - 1507; // <|translate|> + vocab.token_transcribe = vocab.n_vocab - 1506; // <|transcribe|> + vocab.token_solm = vocab.n_vocab - 1505; // <|startoflm|> + vocab.token_prev = vocab.n_vocab - 1504; // <|startofprev|> + 
vocab.token_nosp = vocab.n_vocab - 1503; // <|nospeech|> + vocab.token_not = vocab.n_vocab - 1502; // <|notimestamps|> + vocab.token_beg = vocab.n_vocab - 1501; // timestamps from <|0.00|> to <|30.00|>, 1501 tokens if (n_vocab < model.hparams.n_vocab) { WHISPER_LOG_INFO("%s: adding %d extra tokens\n", __func__, model.hparams.n_vocab - n_vocab);