From f361cdd070ff0f51f65040c0ecae65d7740396e8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 21 Jan 2026 10:17:18 +0200 Subject: [PATCH 1/5] mla : pass V as a view of K to the FA op --- ggml/src/ggml-cuda/fattn-common.cuh | 7 +++++-- src/llama-graph.cpp | 5 +++++ src/models/deepseek2.cpp | 4 ++-- tests/test-backend-ops.cpp | 19 ++++++++++++++++++- 4 files changed, 30 insertions(+), 5 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8468ba8488d..a781fb91f5b 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -778,12 +778,15 @@ void launch_fattn( ) { constexpr int ncols = ncols1 * ncols2; - const bool is_mla = DV == 512; // TODO better parameterization - const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; + // TODO: make this more generic by removing the notion of "MLA". + // for example "is V a view of K?" so we can skip loading it. + // V strides should be driven by V itself and avoid assumption of the data layout + const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K; + GGML_ASSERT(V || is_mla); const ggml_tensor * mask = dst->src[3]; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 57485c534ee..5ebd0cf8aac 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha( v = ggml_transpose(ctx0, v); } + // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K + if (v_mla) { + v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0); + } + // this can happen when KV cache is not used (e.g. an embedding model with non-causal attn) if (k->type == GGML_TYPE_F32) { k = ggml_cast(ctx0, k, GGML_TYPE_F16); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index ca63a62ad1b..4f4c7ac2922 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens} // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0); cb(Qcur, "Qcur", il); kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens); cb(kv_cmpr, "kv_cmpr_reshape", il); // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens} - ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0); cb(Kcur, "Kcur", il); // {kv_lora_rank, 1, n_tokens} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 9f61c6483de..66db0f2e2fc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6122,7 +6122,24 @@ struct test_flash_attn_ext : public test_case { ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache ggml_set_name(k, "k"); - ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + ggml_tensor * v = nullptr; + if (hsk_padded == 576 && hsv_padded == 512) { + // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes + + // in this branch, the V cache is sub-view of the K cache. this is used by some MLA-based models + // for more info: + // - https://github.com/ggml-org/llama.cpp/pull/13435 + // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392 + v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0); + + // note: this is the currently assumed layout by the CUDA FA implementation + // however, this layout is problematic, because the offset can become very inconveniet for quantized KV types (i.e. not multiple of 16) + // so the models/deepseek2.cpp implementation has been updated to concat the PE data at the end of the row, resulting in the layout above + // the CUDA implementation has to be updated to not make an assumption abou the layout, and instead use the V tensor as it is + //v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], ggml_row_size(type_KV, hsk_padded - hsv_padded)/ggml_blck_size(type_KV)); + } else { + v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache + } ggml_set_name(v, "v"); ggml_tensor * m = nullptr; From c6215d6f13f62187295ae8f1ac8683bdc619d585 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 21 Jan 2026 12:11:19 +0200 Subject: [PATCH 2/5] cuda : adjust mla logic to new layout --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 8cca89c2bfa..88a48547cb1 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -794,7 +794,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // For MLA K and V have the same data. // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV; + constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity); @@ -1552,7 +1552,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; @@ -1596,7 +1596,7 @@ static __global__ void flash_attn_ext_f16( (const half *) (mask + nb33*(sequence % ne33)); float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2); - const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); + const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr; const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; From 69d4fd7e8fd8012a60f1e76dc412c71c7a2d2517 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 22 Jan 2026 14:49:43 +0200 Subject: [PATCH 3/5] kv-cache : fix rope shift --- src/llama-kv-cache.cpp | 8 ++++++-- src/models/deepseek2.cpp | 5 ++--- src/models/minicpm3.cpp | 1 + src/models/plm.cpp | 1 + 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index fd9f97d52e8..a7327c49874 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co const auto & n_embd_head_k = hparams.n_embd_head_k; //const auto & n_embd_head_v = hparams.n_embd_head_v; + const auto & n_rot = hparams.n_rot; + + const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0; + auto inp = std::make_unique(this); inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream); @@ -1614,10 +1618,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co ggml_tensor * k = ggml_view_3d(ctx, layer.k, - n_embd_head_k, n_head_kv, get_size()*n_stream, + n_rot, n_head_kv, get_size()*n_stream, ggml_row_size(layer.k->type, n_embd_head_k), ggml_row_size(layer.k->type, n_embd_k_gqa), - 0); + ggml_row_size(layer.k->type, n_embd_nope)); ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 4f4c7ac2922..c404c1946d0 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Vcur = ggml_cont(ctx0, Vcur); cb(Vcur, "Vcur_cont", il); - // note: rope must go first for in-place context shifting in build_rope_shift() - ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0); + ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0); cb(Qcur, "Qcur", il); - ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0); + ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0); cb(Kcur, "Kcur", il); if (inp_attn_scale) { diff --git a/src/models/minicpm3.cpp b/src/models/minicpm3.cpp index f374a9fd030..297cc34ba58 100644 --- a/src/models/minicpm3.cpp +++ b/src/models/minicpm3.cpp @@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; diff --git a/src/models/plm.cpp b/src/models/plm.cpp index 481cbba6907..612a487c564 100644 --- a/src/models/plm.cpp +++ b/src/models/plm.cpp @@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params & const uint32_t n_embd_head_qk_rope = hparams.n_rot; const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot; + const uint32_t kv_lora_rank = hparams.n_lora_kv; ggml_tensor * cur; From e2a230a5d64d5a32d5f98fbb992d1dbeeff0eb6d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 22 Jan 2026 15:52:08 +0200 Subject: [PATCH 4/5] tests : remove comment --- tests/test-backend-ops.cpp | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 66db0f2e2fc..146d05f53bc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -6130,13 +6130,8 @@ struct test_flash_attn_ext : public test_case { // for more info: // - https://github.com/ggml-org/llama.cpp/pull/13435 // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392 + // - https://github.com/ggml-org/llama.cpp/pull/18986 v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0); - - // note: this is the currently assumed layout by the CUDA FA implementation - // however, this layout is problematic, because the offset can become very inconveniet for quantized KV types (i.e. not multiple of 16) - // so the models/deepseek2.cpp implementation has been updated to concat the PE data at the end of the row, resulting in the layout above - // the CUDA implementation has to be updated to not make an assumption abou the layout, and instead use the V tensor as it is - //v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], ggml_row_size(type_KV, hsk_padded - hsv_padded)/ggml_blck_size(type_KV)); } else { v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache } From fa8213dddcb5fa9f2ebb2ce6d9e334ec1ed55ffe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 22 Jan 2026 16:13:48 +0200 Subject: [PATCH 5/5] cuda : fix reusable_cutoff MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Johannes Gäßler --- ggml/src/ggml-cuda/fattn-mma-f16.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 88a48547cb1..203569e3459 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // For MLA K and V have the same data. // Therefore, iterate over V in reverse and re-use the data if possible. static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented"); - constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; + // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV; + constexpr int reusable_cutoff = DV; // TODO implement properly #if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE) T_A_VKQ A_identity; make_identity_mat(A_identity);