diff --git a/src/whisper.cpp b/src/whisper.cpp index c4f912a9a4d..eb5cd8f3947 100644 --- a/src/whisper.cpp +++ b/src/whisper.cpp @@ -600,6 +600,9 @@ struct whisper_hparams { int32_t n_mels = 80; int32_t ftype = 1; float eps = 1e-5f; + int32_t n_audio_conv1_kernel = 3; + int32_t n_audio_window_size = 0; + int32_t n_audio_last_window_layer = -1; }; // audio encoding layer @@ -1517,6 +1520,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con read_safe(loader, hparams.n_text_layer); read_safe(loader, hparams.n_mels); read_safe(loader, hparams.ftype); + read_safe(loader, hparams.n_audio_conv1_kernel); + read_safe(loader, hparams.n_audio_window_size); + read_safe(loader, hparams.n_audio_last_window_layer); assert(hparams.n_text_state == hparams.n_audio_state); @@ -1757,7 +1763,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // encoder model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx)); - model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state)); + model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, hparams.n_audio_conv1_kernel, n_mels, n_audio_state)); model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state)); model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state)); @@ -2095,6 +2101,15 @@ static struct ggml_cgraph * whisper_build_graph_encoder( struct ggml_tensor * inpL = cur; + struct ggml_tensor * window_mask = nullptr; + const int window_size = hparams.n_audio_window_size; + const int last_window_layer = hparams.n_audio_last_window_layer; + if (window_size > 0 && last_window_layer >= 0) { + window_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_ctx, 1); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + } + for (int il = 0; il < n_layer; ++il) { const auto & layer = model.layers_encoder[il]; @@ -2156,7 +2171,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder( ggml_element_size(kv_pad.v)*n_state_head, 0); - cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f); + struct ggml_tensor * attn_mask_fa = (window_mask && il <= last_window_layer) + ? ggml_cast(ctx0, window_mask, GGML_TYPE_F16) : nullptr; + cur = ggml_flash_attn_ext(ctx0, Q, K, V, attn_mask_fa, KQscale, 0.0f, 0.0f); cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx); } else { @@ -2170,7 +2187,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder( // K * Q struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); - struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f); + struct ggml_tensor * enc_attn_mask = (window_mask && il <= last_window_layer) ? window_mask : nullptr; + struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, enc_attn_mask, KQscale, 0.0f); struct ggml_tensor * V = ggml_cast(ctx0, @@ -2444,6 +2462,25 @@ static bool whisper_encode_internal( return false; } + { + struct ggml_tensor * wmask = ggml_graph_get_tensor(gf, "window_mask"); + if (wmask) { + const int n_ctx = wstate.exp_n_audio_ctx > 0 + ? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx; + const int ws = wctx.model.hparams.n_audio_window_size; + const int half_w = ws / 2; + std::vector mask_data(n_ctx * n_ctx); + for (int i = 0; i < n_ctx; ++i) { + for (int j = 0; j < n_ctx; ++j) { + mask_data[i * n_ctx + j] = + (abs(i - j) <= half_w) ? 0.0f : -INFINITY; + } + } + ggml_backend_tensor_set(wmask, mask_data.data(), 0, + n_ctx * n_ctx * sizeof(float)); + } + } + if (!ggml_graph_compute_helper(sched, gf, n_threads)) { return false; } @@ -6958,6 +6995,11 @@ int whisper_full_with_state( } else { prompt_init.push_back(whisper_token_transcribe(ctx)); } + } else if (ctx->model.hparams.n_audio_window_size > 0) { + const int lang_id = whisper_lang_id(params.language); + state->lang_id = lang_id; + prompt_init.push_back(whisper_token_lang(ctx, lang_id)); + prompt_init.push_back(whisper_token_transcribe(ctx)); } // first release distilled models require the "no_timestamps" token