7 changes: 5 additions & 2 deletions ggml/src/ggml-cuda/fattn-common.cuh
@@ -778,12 +778,15 @@ void launch_fattn(
) {
constexpr int ncols = ncols1 * ncols2;

-    const bool is_mla = DV == 512; // TODO better parameterization

const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];

+    // TODO: make this more generic by removing the notion of "MLA".
+    // for example "is V a view of K?" so we can skip loading it.
+    // V strides should be driven by V itself and avoid assumption of the data layout
+    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;

Comment on lines +785 to +789

Member Author:
@JohannesGaessler Let me know if this is clear. The proposal is that when V = ggml_view(K), the implementation can use this information, for example to avoid redundant data loads. But technically this is completely optional: the implementation can also just read from V directly.

This way, if the user code insists on passing different V data, that should also work: in that case V won't be a view of K, so it will be treated as a regular V tensor.
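
As a rough illustration of the proposal (a sketch, not code from this PR; the tensor names and dimensions are hypothetical, only ggml_view_4d and the is_mla check come from the actual changes):

    // hypothetical MLA-style cache: each K row is 576 values wide and V re-uses
    // the leading 512 values of the same row instead of owning separate data
    ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 576, n_kv, n_head_kv, n_seq);

    ggml_tensor * v = ggml_view_4d(ctx, k,
            512, n_kv, n_head_kv, n_seq,  // ne: only the first 512 values per row
            k->nb[1], k->nb[2], k->nb[3], // nb: strides taken from K itself
            0);                           // offset 0: V starts at the beginning of each K row

    // a backend can then detect the aliasing and skip loading V separately:
    const bool is_mla = v->op == GGML_OP_VIEW && v->src[0] == k;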

Contributor:

You should adjust the kernel selection logic in fattn.cu to only use the MMA kernel if this boolean is true. The MMA kernel has the MLA-specific optimization of re-using the K data previously loaded for the KQ computation as V data when computing VKQ. But the tile kernel simply loads the data from the K and V pointers with no re-use. So for now I would suggest conditioning the use of the MMA kernel on this boolean and using the tile kernel as a fallback.

We could in principle compile multiple template specializations but currently they would be unused for real models.
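
A sketch of what that gating could look like (the dispatch-function names and the shape check are assumptions for illustration; the actual selection logic in fattn.cu is more involved):

    // illustrative dispatch for the MLA head shapes (DKQ = 576, DV = 512):
    // the MMA kernel re-uses already-loaded K data as V data, which is only
    // correct when V actually aliases K, so fall back to the tile kernel otherwise
    const bool is_mla = V->op == GGML_OP_VIEW && V->src[0] == K;

    if (K->ne[0] == 576 && V->ne[0] == 512 && !is_mla) {
        ggml_cuda_flash_attn_ext_tile(ctx, dst);    // loads V independently, always correct
    } else {
        ggml_cuda_flash_attn_ext_mma_f16(ctx, dst); // MLA-optimized K re-use path
    }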

GGML_ASSERT(V || is_mla);

const ggml_tensor * mask = dst->src[3];
7 changes: 4 additions & 3 deletions ggml/src/ggml-cuda/fattn-mma-f16.cuh
@@ -794,7 +794,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
// For MLA K and V have the same data.
// Therefore, iterate over V in reverse and re-use the data if possible.
static_assert(!mla || nstages <= 1, "combination of MLA and multi-stage loading not implemented");
-    constexpr int reusable_cutoff = mla ? (DKQ - 1) - (DKQ - 1) % (2*nbatch_K2) - (DKQ - DV) : DV;
+    // constexpr int reusable_cutoff = mla ? (DV - 1) - (DV - 1) % (2*nbatch_K2) : DV;
+    constexpr int reusable_cutoff = DV; // TODO implement properly
#if defined(AMD_WMMA_AVAILABLE) && !defined(LDMATRIX_TRANS_AVAILABLE)
T_A_VKQ A_identity;
make_identity_mat(A_identity);
@@ -1552,7 +1553,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
@@ -1596,7 +1597,7 @@ static __global__ void flash_attn_ext_f16(
(const half *) (mask + nb33*(sequence % ne33));
float2 * dstk = ((float2 *) dst) + (sequence*ne01.z*ne02 + head0) * (DV/2);

-    const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
+    const half2 * V_h2 = mla ? K_h2 : (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio));
const float * sinks_f = sinks ? (const float *) sinks + head0 : nullptr;

const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f;
5 changes: 5 additions & 0 deletions src/llama-graph.cpp
@@ -1565,6 +1565,11 @@ ggml_tensor * llm_graph_context::build_attn_mha(
v = ggml_transpose(ctx0, v);
}

+    // TODO: update llama_kv_cache to not store V cache in the MLA case and automatically return a view of K
+    if (v_mla) {
+        v = ggml_view_4d(ctx0, k, v->ne[0], v->ne[1], v->ne[2], v->ne[3], k->nb[1], k->nb[2], k->nb[3], 0);
+    }

// this can happen when KV cache is not used (e.g. an embedding model with non-causal attn)
if (k->type == GGML_TYPE_F32) {
k = ggml_cast(ctx0, k, GGML_TYPE_F16);
8 changes: 6 additions & 2 deletions src/llama-kv-cache.cpp
@@ -1594,6 +1594,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
const auto & n_embd_head_k = hparams.n_embd_head_k;
//const auto & n_embd_head_v = hparams.n_embd_head_v;

+    const auto & n_rot = hparams.n_rot;
+
+    const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - n_rot : 0;

auto inp = std::make_unique<llm_graph_input_k_shift>(this);

inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, (int64_t) get_size()*n_stream);
@@ -1614,10 +1618,10 @@

ggml_tensor * k =
ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                n_rot, n_head_kv, get_size()*n_stream,
ggml_row_size(layer.k->type, n_embd_head_k),
ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
+                ggml_row_size(layer.k->type, n_embd_nope));

ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);

9 changes: 4 additions & 5 deletions src/models/deepseek2.cpp
@@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr

// {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
// note: rope must go first for in-place context shifting in build_rope_shift()
-    ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+    ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
@ngxson (Contributor), Jan 21, 2026:
a side-effect of this change is that build_rope_shift will no longer work correctly (may need to use a view when shifting)

Member Author:

This should fix it:

diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index fd9f97d52..2cc9efed8 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1614,10 +1614,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
 
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                n_embd_head_k, n_head_kv, get_size()*n_stream,
+                n_embd_head_k - hparams.n_lora_kv, n_head_kv, get_size()*n_stream,
                 ggml_row_size(layer.k->type, n_embd_head_k),
                 ggml_row_size(layer.k->type, n_embd_k_gqa),
-                0);
+                ggml_row_size(layer.k->type, hparams.n_lora_kv));
 
         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
 

Member Author:

The rope shift logic for PLM and minicpm3 was broken because it didn't take into account the "nope" portion of the K embeddings. This is now fixed with 69d4fd7.
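
For reference, a sketch of the K row layout the fix relies on (this restates the committed change to build_graph_shift() above, with the layout spelled out; variable names follow the hparams used there):

    // per-head K row layout for MLA-style models (hparams.n_lora_kv > 0):
    //
    //   [ "nope" part: n_embd_nope values, not position-encoded | "rope" part: n_rot values ]
    //
    // a context shift must re-apply RoPE only to the trailing n_rot values of each row:
    const auto n_embd_nope = hparams.n_lora_kv > 0 ? n_embd_head_k - hparams.n_rot : 0;

    ggml_tensor * k_rope = ggml_view_3d(ctx, layer.k,
            hparams.n_rot, n_head_kv, get_size()*n_stream,
            ggml_row_size(layer.k->type, n_embd_head_k), // row stride: full head width
            ggml_row_size(layer.k->type, n_embd_k_gqa),  // stride between cache cells
            ggml_row_size(layer.k->type, n_embd_nope));  // byte offset: skip the nope part
    // for non-MLA models n_embd_nope == 0, so the view covers the whole row as before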

cb(Qcur, "Qcur", il);

kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
cb(kv_cmpr, "kv_cmpr_reshape", il);

// {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-    ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+    ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
cb(Kcur, "Kcur", il);

// {kv_lora_rank, 1, n_tokens}
@@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
Vcur = ggml_cont(ctx0, Vcur);
cb(Vcur, "Vcur_cont", il);

-    // note: rope must go first for in-place context shifting in build_rope_shift()
-    ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+    ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
cb(Qcur, "Qcur", il);

-    ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+    ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
cb(Kcur, "Kcur", il);

if (inp_attn_scale) {
1 change: 1 addition & 0 deletions src/models/minicpm3.cpp
@@ -9,6 +9,7 @@ llm_build_minicpm3::llm_build_minicpm3(const llama_model & model, const llm_grap

const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;

const uint32_t kv_lora_rank = hparams.n_lora_kv;

ggml_tensor * cur;
1 change: 1 addition & 0 deletions src/models/plm.cpp
@@ -5,6 +5,7 @@ llm_build_plm::llm_build_plm(const llama_model & model, const llm_graph_params &

const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;

const uint32_t kv_lora_rank = hparams.n_lora_kv;

ggml_tensor * cur;
14 changes: 13 additions & 1 deletion tests/test-backend-ops.cpp
@@ -6122,7 +6122,19 @@ struct test_flash_attn_ext : public test_case {
ggml_tensor * k = create_permuted(type_KV, hsk_padded, kv, nh, nr23[1], true); // the K tensor is usually a view of the K cache
ggml_set_name(k, "k");

-    ggml_tensor * v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+    ggml_tensor * v = nullptr;
+    if (hsk_padded == 576 && hsv_padded == 512) {
+        // TODO: this branch should become a separate test case parameter instead of hardcoding this for these head shapes
+
+        // in this branch, the V cache is a sub-view of the K cache. this is used by some MLA-based models
+        // for more info:
+        // - https://github.com/ggml-org/llama.cpp/pull/13435
+        // - https://github.com/ggml-org/llama.cpp/pull/18953#issuecomment-3774948392
+        // - https://github.com/ggml-org/llama.cpp/pull/18986
+        v = ggml_view_4d(ctx, k, hsv_padded, kv, nh, nr23[1], k->nb[1], k->nb[2], k->nb[3], 0);
+    } else {
+        v = create_permuted(type_KV, hsv_padded, kv, nh, nr23[1], true); // the V tensor is usually a view of the V cache
+    }
ggml_set_name(v, "v");

ggml_tensor * m = nullptr;