diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 5bca8230b9b..e6ec3054daf 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2656,14 +2656,18 @@ llm_graph_input_attn_k_dsa * llm_graph_context::build_attn_inp_k_dsa() const { inp->self_k_idxs_mla = mctx_cur->get_mla()->build_input_k_idxs(ctx0, ubatch); inp->self_kq_mask_mla = build_attn_inp_kq_mask(ctx0, mctx_cur->get_mla(), ubatch, cparams); - inp->self_kq_mask_mla_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_mla, GGML_TYPE_F16) : inp->self_kq_mask_mla; + inp->self_kq_mask_mla_cnv = inp->self_kq_mask_mla; } { inp->self_k_idxs_lid = mctx_cur->get_lid()->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams); - inp->self_kq_mask_lid_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_lid, GGML_TYPE_F16) : inp->self_kq_mask_lid; + // ensure F32 mask + auto cparams_copy = cparams; + cparams_copy.flash_attn = false; + + inp->self_kq_mask_lid = build_attn_inp_kq_mask(ctx0, mctx_cur->get_lid(), ubatch, cparams_copy); + inp->self_kq_mask_lid_cnv = inp->self_kq_mask_lid; inp->self_k_rot_lid = mctx_cur->get_lid()->build_input_k_rot(ctx0); } diff --git a/src/llama-graph.h b/src/llama-graph.h index d07a084a8d6..eab82bd0d70 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -399,10 +399,10 @@ class llm_graph_input_attn_k_dsa : public llm_graph_input_i { ggml_tensor * self_k_idxs_mla = nullptr; // I64 [n_batch] ggml_tensor * self_k_idxs_lid = nullptr; // I64 [n_batch] - ggml_tensor * self_kq_mask_mla = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] - ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla = nullptr; // F32/F16 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_mla_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream] + ggml_tensor * self_kq_mask_lid_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream] ggml_tensor * self_k_rot_lid = nullptr;