mla : make the V tensor a view of K #18986
Changes from all commits:

- f361cdd
- c6215d6
- 69d4fd7
- e2a230a
- fa8213d
```diff
@@ -124,14 +124,14 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
-        // note: rope must go first for in-place context shifting in build_rope_shift()
-        ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope_absorbed, 0);
+        ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
```
**Contributor:**

a side-effect of this change is that

**Member (Author):**

This should fix it:

```diff
diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp
index fd9f97d52..2cc9efed8 100644
--- a/src/llama-kv-cache.cpp
+++ b/src/llama-kv-cache.cpp
@@ -1614,10 +1614,10 @@ ggml_cgraph * llama_kv_cache::build_graph_shift(llm_graph_result * res, llama_co
         ggml_tensor * k =
             ggml_view_3d(ctx, layer.k,
-                    n_embd_head_k, n_head_kv, get_size()*n_stream,
+                    n_embd_head_k - hparams.n_lora_kv, n_head_kv, get_size()*n_stream,
                     ggml_row_size(layer.k->type, n_embd_head_k),
                     ggml_row_size(layer.k->type, n_embd_k_gqa),
-                    0);
+                    ggml_row_size(layer.k->type, hparams.n_lora_kv));

         ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l);
```

**Member (Author):**

The rope shift logic for PLM and minicpm3 has been broken, since it didn't take into account the "nope" portion of the K embeddings. This is fixed now with 69d4fd7.
```diff
         cb(Qcur, "Qcur", il);

         kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
         cb(kv_cmpr, "kv_cmpr_reshape", il);

         // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
-        ggml_tensor * Kcur = ggml_concat(ctx0, k_pe, kv_cmpr, 0);
+        ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
         cb(Kcur, "Kcur", il);

         // {kv_lora_rank, 1, n_tokens}
```
```diff
@@ -169,11 +169,10 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr
         Vcur = ggml_cont(ctx0, Vcur);
         cb(Vcur, "Vcur_cont", il);

-        // note: rope must go first for in-place context shifting in build_rope_shift()
-        ggml_tensor * Qcur = ggml_concat(ctx0, q_pe, q_nope, 0);
+        ggml_tensor * Qcur = ggml_concat(ctx0, q_nope, q_pe, 0);
         cb(Qcur, "Qcur", il);

-        ggml_tensor * Kcur = ggml_concat(ctx0, ggml_repeat(ctx0, k_pe, q_pe), k_nope, 0);
+        ggml_tensor * Kcur = ggml_concat(ctx0, k_nope, ggml_repeat(ctx0, k_pe, q_pe), 0);
         cb(Kcur, "Kcur", il);

         if (inp_attn_scale) {
```
**Member (Author):**

@JohannesGaessler Let me know if this is clear. The proposal is that when `V = ggml_view(K)`, the implementation can use this information, for example to avoid extra loads of data. But technically this is completely optional, and the implementation can also just read from `V` directly.

This way, if the user code insists on passing different `V` data, that should also work: in that case `V` won't be a view of `K`, so it will be treated as a regular `V` tensor.

**Contributor:**

You should adjust the kernel selection logic in `fattn.cu` to only use the MMA kernel if this boolean is true. The MMA kernel has the MLA-specific optimization of re-using the K data that was previously loaded for the calculation of KQ as V data when calculating VKQ. But the tile kernel simply loads the data from the K and V pointers with no re-use. So for now I would suggest that we condition the use of the MMA kernel on this boolean and use the tile kernel as a fallback. We could in principle compile multiple template specializations, but currently they would be unused for real models.