Commit d58dee8

Authored by saood06, with Stanisław Szymczyk (sszymczy) and Iwan Kawrakow
Deepseek MLA Optimizations V2 (#195)
* Avoid allocating MHA KV cache when MLA is turned on
* Added missing gguf-py file
* Added final optimizations (co-authored by Stanisław Szymczyk)
* Make sure we do have wk_b and wv_b before enabling MLA

Co-authored-by: Stanisław Szymczyk <[email protected]>
Co-authored-by: Iwan Kawrakow <[email protected]>
Parent: 3aaf602 · Commit: d58dee8

2 files changed: +53 lines, -21 lines

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -446,6 +446,14 @@ class TensorNameMap:
             "model.layers.{bid}.self_attn.kv_b_proj", # deepseek2
         ),
 
+        MODEL_TENSOR.ATTN_K_B: (
+            "model.layers.{bid}.self_attn.k_b_proj", # deepseek2
+        ),
+
+        MODEL_TENSOR.ATTN_V_B: (
+            "model.layers.{bid}.self_attn.v_b_proj", # deepseek2
+        ),
+
         MODEL_TENSOR.ATTN_Q_A_NORM: (
             "model.layers.{bid}.self_attn.q_a_layernorm", # deepseek2
         ),
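
These two mappings let the GGUF converter pick up per-layer k_b_proj and v_b_proj tensors, which the llama.cpp changes below consume as model.layers[i].wk_b and wv_b. Going by the ggml_view_3d calls in the last llama.cpp hunk, wk_b is expected to hold one [n_embd_head_qk_nope x kv_lora_rank] block per head and wv_b one [kv_lora_rank x n_embd_head_v] block per head. A hypothetical sketch of that per-head slicing follows; the dimension values are invented and the flat storage layout is inferred from the strides in the diff, not stated by the commit.

```cpp
// Hypothetical sketch of how a flat wk_b tensor is sliced into per-head blocks,
// mirroring the ggml_view_3d call in the last llama.cpp hunk. The dimension
// values are invented and the flat layout is inferred from the strides used
// there; only the shape arithmetic matters.
#include <cstdio>
#include "ggml.h"

int main() {
    const int64_t n_embd_head_qk_nope = 128;  // assumed
    const int64_t kv_lora_rank        = 512;  // assumed
    const int64_t n_head              = 16;   // assumed

    struct ggml_init_params params = { 16*1024*1024, NULL, /*no_alloc=*/ true };
    struct ggml_context * ctx = ggml_init(params);

    // wk_b stored flat: n_head blocks of [n_embd_head_qk_nope, kv_lora_rank] back to back.
    struct ggml_tensor * wk_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,
                                                   n_embd_head_qk_nope, kv_lora_rank * n_head);

    // Per-head 3D view, as in the diff: [n_embd_head_qk_nope, kv_lora_rank, n_head].
    struct ggml_tensor * wk_b_heads = ggml_view_3d(ctx, wk_b,
        n_embd_head_qk_nope, kv_lora_rank, n_head,
        ggml_row_size(wk_b->type, n_embd_head_qk_nope),
        ggml_row_size(wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);

    printf("wk_b per-head view: [%lld, %lld, %lld]\n",
           (long long)wk_b_heads->ne[0], (long long)wk_b_heads->ne[1], (long long)wk_b_heads->ne[2]);

    ggml_free(ctx);
    return 0;
}
```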

src/llama.cpp

Lines changed: 45 additions & 21 deletions
@@ -3173,8 +3173,17 @@ static bool llama_kv_cache_init(
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
 
         struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
-        ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
-        ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        ggml_tensor * k;
+        ggml_tensor * v;
+        if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
+            k = ggml_new_tensor_1d(ctx, type_k, 1);
+            v = ggml_new_tensor_1d(ctx, type_v, 1);
+        }
+        else {
+            k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
+            v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
+        }
+
         ggml_format_name(k, "cache_k_l%d", i);
         ggml_format_name(v, "cache_v_l%d", i);
         cache.k_l.push_back(k);
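
With MLA active and both wk_b and wv_b present, the per-layer K/V buffers shrink to single-element placeholders, so the full MHA-sized allocation of n_embd_k_gqa*kv_size and n_embd_v_gqa*kv_size elements is skipped. As a rough, purely hypothetical illustration of the order of magnitude involved (the dimensions, layer count, and cache type below are made-up values, not taken from this commit or any particular model):

```cpp
// Hypothetical back-of-the-envelope estimate of the per-layer MHA K/V buffers
// that are no longer allocated when the MLA path is taken. Every value below is
// an assumption chosen for illustration; none of them comes from this commit.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t n_embd_k_gqa = 24576;   // assumed K row size per layer
    const uint64_t n_embd_v_gqa = 16384;   // assumed V row size per layer
    const uint64_t kv_size      = 32768;   // assumed number of cache slots
    const uint64_t n_layer      = 60;      // assumed layer count
    const double   bytes_per_el = 2.0;     // assumed f16 cache type

    // Size of the two buffers the else-branch would have created, per layer.
    const double per_layer = (double)(n_embd_k_gqa + n_embd_v_gqa) * kv_size * bytes_per_el;
    const double total_gib = per_layer * n_layer / (1024.0 * 1024.0 * 1024.0);

    // The MLA branch keeps only a 1-element placeholder per buffer instead.
    printf("MHA KV cache skipped: %.1f GiB across %llu layers\n",
           total_gib, (unsigned long long)n_layer);
    return 0;
}
```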
@@ -13368,6 +13377,10 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
+        // whether to use n_tokens as the matrix dimension during multiplication or n_head
+        // n_tokens is higher during prompt processing, this allows to optimize for this case
+        bool pp_opt = n_tokens > n_head;
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;

@@ -13496,43 +13509,54 @@
                 struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
                 cb(wk_b, "wk_b", il);
 
-                struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
-                cb(q_nope_perm, "q_nope_perm", il);
+                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
+                cb(q_nope, "q_nope_perm", il);
 
-                struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm);
+                struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope);
                 cb(q_nope2, "q_nope2", il);
 
-                struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
-                cb(q_nope2_perm, "q_nope2_perm", il);
-
-                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
+                if (!pp_opt) {
+                    q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
+                    cb(q_nope2, "q_nope2_perm", il);
+                }
+                struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
                 cb(kq_nope, "kq_nope", il);
 
-                // Huh? This is not used anywhere
-                //struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
-                //cb(q_pe_perm, "q_pe_perm", il);
+                if (!pp_opt) {
+                    kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
+                    cb(kq_nope, "kq_nope_perm", il);
+                }
 
+                if (pp_opt) {
+                    q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
+                    cb(q_pe, "q_pe_perm", il);
+                }
                 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
                 cb(kq_pe, "kq_pe", il);
 
+                if (!pp_opt) {
+                    kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
+                    cb(kq_pe, "kq_pe_perm", il);
+                }
+
                 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
                 cb(kq, "kq", il);
 
-                // We need this copy because soft_max expects a contiguous tensor
-                kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
-                cb(kq, "kq_perm", il);
-
                 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
                 cb(kq, "kq_soft_max_ext", il);
 
-                struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
-                cb(kq_perm, "kq_soft_max_ext_perm", il);
+                if (!pp_opt) {
+                    kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
+                    cb(kq, "kq_soft_max_ext_perm", il);
+                }
 
-                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm);
+                struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq);
                 cb(kqv_compressed, "kqv_compressed", il);
 
-                kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
-                cb(kqv_compressed, "kqv_compressed_perm", il);
+                if (!pp_opt) {
+                    kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
+                    cb(kqv_compressed, "kqv_compressed_perm", il);
+                }
 
                 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
                 cb(wv_b, "wv_b", il);
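
A side effect of permuting the summands before ggml_add is that the explicit ggml_cont copy in front of ggml_soft_max_ext is no longer needed: a permuted tensor is only a strided view, whereas the result tensor that ggml_add creates has the default contiguous layout. A tiny standalone sketch of that distinction, using only standard ggml calls (shapes are arbitrary):

```cpp
// Standalone sketch (not from the commit): ggml_permute only rearranges the
// strides of a tensor, producing a non-contiguous view, while ggml_cont adds a
// node that materializes a contiguous copy. Shapes here are arbitrary.
#include <cstdio>
#include "ggml.h"

int main() {
    struct ggml_init_params params = { 16*1024*1024, NULL, /*no_alloc=*/ false };
    struct ggml_context * ctx = ggml_init(params);

    // A contiguous [n_kv, n_head, n_tokens]-style tensor.
    struct ggml_tensor * kq = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 64, 8, 32);

    // Swap dims 1 and 2: same data, permuted strides.
    struct ggml_tensor * kq_perm = ggml_permute(ctx, kq, 0, 2, 1, 3);

    // A node whose result, once the graph is evaluated, is a contiguous copy.
    struct ggml_tensor * kq_cont = ggml_cont(ctx, kq_perm);

    printf("kq      contiguous: %d\n", ggml_is_contiguous(kq));       // 1
    printf("kq_perm contiguous: %d\n", ggml_is_contiguous(kq_perm));  // 0
    printf("kq_cont contiguous: %d\n", ggml_is_contiguous(kq_cont));  // 1

    ggml_free(ctx);
    return 0;
}
```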
