Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 45 additions & 3 deletions src/whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,9 @@ struct whisper_hparams {
int32_t n_mels = 80;
int32_t ftype = 1;
float eps = 1e-5f;
int32_t n_audio_conv1_kernel = 3;
int32_t n_audio_window_size = 0;
int32_t n_audio_last_window_layer = -1;
};

// audio encoding layer
Expand Down Expand Up @@ -1517,6 +1520,9 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
read_safe(loader, hparams.n_text_layer);
read_safe(loader, hparams.n_mels);
read_safe(loader, hparams.ftype);
read_safe(loader, hparams.n_audio_conv1_kernel);
read_safe(loader, hparams.n_audio_window_size);
read_safe(loader, hparams.n_audio_last_window_layer);

assert(hparams.n_text_state == hparams.n_audio_state);

Expand Down Expand Up @@ -1757,7 +1763,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
// encoder
model.e_pe = create_tensor(ASR_TENSOR_ENC_POS_EMBD, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx));

model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state));
model.e_conv_1_w = create_tensor(ASR_TENSOR_CONV1_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, hparams.n_audio_conv1_kernel, n_mels, n_audio_state));
model.e_conv_1_b = create_tensor(ASR_TENSOR_CONV1_BIAS, ASR_SYSTEM_ENCODER, ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state));

model.e_conv_2_w = create_tensor(ASR_TENSOR_CONV2_WEIGHT, ASR_SYSTEM_ENCODER, ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state));
Expand Down Expand Up @@ -2095,6 +2101,15 @@ static struct ggml_cgraph * whisper_build_graph_encoder(

struct ggml_tensor * inpL = cur;

struct ggml_tensor * window_mask = nullptr;
const int window_size = hparams.n_audio_window_size;
const int last_window_layer = hparams.n_audio_last_window_layer;
if (window_size > 0 && last_window_layer >= 0) {
window_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_ctx, n_ctx, 1);
ggml_set_name(window_mask, "window_mask");
ggml_set_input(window_mask);
}

for (int il = 0; il < n_layer; ++il) {
const auto & layer = model.layers_encoder[il];

Expand Down Expand Up @@ -2156,7 +2171,9 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
ggml_element_size(kv_pad.v)*n_state_head,
0);

cur = ggml_flash_attn_ext(ctx0, Q, K, V, nullptr, KQscale, 0.0f, 0.0f);
struct ggml_tensor * attn_mask_fa = (window_mask && il <= last_window_layer)
? ggml_cast(ctx0, window_mask, GGML_TYPE_F16) : nullptr;
cur = ggml_flash_attn_ext(ctx0, Q, K, V, attn_mask_fa, KQscale, 0.0f, 0.0f);

cur = ggml_reshape_2d(ctx0, cur, n_state, n_ctx);
} else {
Expand All @@ -2170,7 +2187,8 @@ static struct ggml_cgraph * whisper_build_graph_encoder(
// K * Q
struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);

struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, nullptr, KQscale, 0.0f);
struct ggml_tensor * enc_attn_mask = (window_mask && il <= last_window_layer) ? window_mask : nullptr;
struct ggml_tensor * KQ_soft_max = ggml_soft_max_ext(ctx0, KQ, enc_attn_mask, KQscale, 0.0f);

struct ggml_tensor * V =
ggml_cast(ctx0,
Expand Down Expand Up @@ -2444,6 +2462,25 @@ static bool whisper_encode_internal(
return false;
}

{
struct ggml_tensor * wmask = ggml_graph_get_tensor(gf, "window_mask");
if (wmask) {
const int n_ctx = wstate.exp_n_audio_ctx > 0
? wstate.exp_n_audio_ctx : wctx.model.hparams.n_audio_ctx;
const int ws = wctx.model.hparams.n_audio_window_size;
const int half_w = ws / 2;
std::vector<float> mask_data(n_ctx * n_ctx);
for (int i = 0; i < n_ctx; ++i) {
for (int j = 0; j < n_ctx; ++j) {
mask_data[i * n_ctx + j] =
(abs(i - j) <= half_w) ? 0.0f : -INFINITY;
}
}
ggml_backend_tensor_set(wmask, mask_data.data(), 0,
n_ctx * n_ctx * sizeof(float));
}
}

if (!ggml_graph_compute_helper(sched, gf, n_threads)) {
return false;
}
Expand Down Expand Up @@ -6958,6 +6995,11 @@ int whisper_full_with_state(
} else {
prompt_init.push_back(whisper_token_transcribe(ctx));
}
} else if (ctx->model.hparams.n_audio_window_size > 0) {
const int lang_id = whisper_lang_id(params.language);
state->lang_id = lang_id;
prompt_init.push_back(whisper_token_lang(ctx, lang_id));
prompt_init.push_back(whisper_token_transcribe(ctx));
}

// first release distilled models require the "no_timestamps" token
Expand Down
Loading