Skip to content

Commit 784fe60

Browse files
authored
Revert "Remove model_name param from Whisper-Metal (#15798)"
This reverts commit 6ca1db6.
1 parent 3826f44 commit 784fe60

File tree

1 file changed: +19 additions, -4 deletions

examples/models/whisper/main.cpp

Lines changed: 19 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,10 @@ DEFINE_string(
3939
audio_path,
4040
"",
4141
"Path to input audio file. Accepts .wav or raw float .bin.");
42+
DEFINE_string(
43+
model_name,
44+
"base",
45+
"Whisper model name (base, small, medium, large, large-v2, large-v3, large-v3-turbo).");
4246
DEFINE_double(
4347
temperature,
4448
0.0,
@@ -110,10 +114,21 @@ int main(int argc, char** argv) {
110114
config.max_new_tokens = FLAGS_max_new_tokens;
111115
config.temperature = static_cast<float>(FLAGS_temperature);
112116

113-
// All Whisper models from HuggingFace now use the v3 tokenizer format
114-
// where token 50257 = <|endoftext|> and token 50258 = <|startoftranscript|>
115-
config.decoder_start_token_id = 50258;
116-
ET_LOG(Info, "Using decoder_start_token_id=50258");
117+
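// Pick decoder_start_token_id from --model_name: large-v2, large-v3 and large-v3-turbo use 50258; any other name (including unrecognized ones) falls back to 50257.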
118+
if (FLAGS_model_name == "large-v2" || FLAGS_model_name == "large-v3" ||
119+
FLAGS_model_name == "large-v3-turbo") {
120+
config.decoder_start_token_id = 50258;
121+
ET_LOG(
122+
Info,
123+
"Using decoder_start_token_id=50258 for model: %s",
124+
FLAGS_model_name.c_str());
125+
} else {
126+
config.decoder_start_token_id = 50257;
127+
ET_LOG(
128+
Info,
129+
"Using decoder_start_token_id=50257 for model: %s",
130+
FLAGS_model_name.c_str());
131+
}
117132

118133
auto result =
119134
runner.transcribe(features, config, [&](const std::string& piece) {

0 commit comments

Comments
 (0)