@@ -3173,8 +3173,17 @@ static bool llama_kv_cache_init(
31733173 const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
31743174
31753175 struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
3176- ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3177- ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
3176+ ggml_tensor * k;
3177+ ggml_tensor * v;
3178+ if (cparams.mla_attn && model.layers[i].wk_b && model.layers[i].wv_b) {
3179+ k = ggml_new_tensor_1d(ctx, type_k, 1);
3180+ v = ggml_new_tensor_1d(ctx, type_v, 1);
3181+ }
3182+ else {
3183+ k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size);
3184+ v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size);
3185+ }
3186+
31783187 ggml_format_name(k, "cache_k_l%d", i);
31793188 ggml_format_name(v, "cache_v_l%d", i);
31803189 cache.k_l.push_back(k);
@@ -13368,6 +13377,10 @@ struct llm_build_context {
1336813377 // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
1336913378 struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
1337013379
13380+ // whether to use n_tokens as the matrix dimension during multiplication or n_head
13381+ // n_tokens is higher during prompt processing, this allows to optimize for this case
13382+ bool pp_opt = n_tokens > n_head;
13383+
1337113384 for (int il = 0; il < n_layer; ++il) {
1337213385 struct ggml_tensor * inpSA = inpL;
1337313386
@@ -13496,43 +13509,54 @@ struct llm_build_context {
1349613509 struct ggml_tensor * wk_b = ggml_view_3d(ctx0, model.layers[il].wk_b, n_embd_head_qk_nope, kv_lora_rank, n_head, ggml_row_size(model.layers[il].wk_b->type, n_embd_head_qk_nope), ggml_row_size(model.layers[il].wk_b->type, kv_lora_rank * n_embd_head_qk_nope), 0);
1349713510 cb(wk_b, "wk_b", il);
1349813511
13499- struct ggml_tensor * q_nope_perm = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
13500- cb(q_nope_perm , "q_nope_perm", il);
13512+ q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
13513+ cb(q_nope , "q_nope_perm", il);
1350113514
13502- struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope_perm );
13515+ struct ggml_tensor * q_nope2 = ggml_mul_mat(ctx0, wk_b, q_nope );
1350313516 cb(q_nope2, "q_nope2", il);
1350413517
13505- struct ggml_tensor * q_nope2_perm = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
13506- cb(q_nope2_perm, "q_nope2_perm", il);
13507-
13508- struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2_perm);
13518+ if (!pp_opt) {
13519+ q_nope2 = ggml_permute(ctx0, q_nope2, 0, 2, 1, 3);
13520+ cb(q_nope2, "q_nope2_perm", il);
13521+ }
13522+ struct ggml_tensor * kq_nope = ggml_mul_mat(ctx0, kv_cache, q_nope2);
1350913523 cb(kq_nope, "kq_nope", il);
1351013524
13511- // Huh? This is not used anywhere
13512- //struct ggml_tensor * q_pe_perm = ggml_permute(ctx0, q_pe, 0, 3, 2, 1);
13513- //cb(q_pe_perm, "q_pe_perm", il);
13525+ if (!pp_opt) {
13526+ kq_nope = ggml_permute(ctx0, kq_nope, 0, 2, 1, 3);
13527+ cb(kq_nope, "kq_nope_perm", il);
13528+ }
1351413529
13530+ if (pp_opt) {
13531+ q_pe = ggml_permute(ctx0, q_pe, 0, 2, 1, 3);
13532+ cb(q_pe, "q_pe_perm", il);
13533+ }
1351513534 struct ggml_tensor * kq_pe = ggml_mul_mat(ctx0, kr_cache, q_pe);
1351613535 cb(kq_pe, "kq_pe", il);
1351713536
13537+ if (!pp_opt) {
13538+ kq_pe = ggml_permute(ctx0, kq_pe, 0, 2, 1, 3);
13539+ cb(kq_pe, "kq_pe_perm", il);
13540+ }
13541+
1351813542 struct ggml_tensor * kq = ggml_add(ctx0, kq_nope, kq_pe);
1351913543 cb(kq, "kq", il);
1352013544
13521- // We need this copy because soft_max expects a contiguous tensor
13522- kq = ggml_cont(ctx0, ggml_permute(ctx0, kq, 0, 2, 1, 3));
13523- cb(kq, "kq_perm", il);
13524-
1352513545 kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, kq_scale, hparams.f_max_alibi_bias);
1352613546 cb(kq, "kq_soft_max_ext", il);
1352713547
13528- struct ggml_tensor * kq_perm = ggml_permute(ctx0, kq, 0, 2, 1, 3);
13529- cb(kq_perm, "kq_soft_max_ext_perm", il);
13548+ if (!pp_opt) {
13549+ kq = ggml_permute(ctx0, kq, 0, 2, 1, 3);
13550+ cb(kq, "kq_soft_max_ext_perm", il);
13551+ }
1353013552
13531- struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq_perm );
13553+ struct ggml_tensor * kqv_compressed = ggml_mul_mat(ctx0, kv_cache_trans, kq );
1353213554 cb(kqv_compressed, "kqv_compressed", il);
1353313555
13534- kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
13535- cb(kqv_compressed, "kqv_compressed_perm", il);
13556+ if (!pp_opt) {
13557+ kqv_compressed = ggml_permute(ctx0, kqv_compressed, 0, 2, 1, 3);
13558+ cb(kqv_compressed, "kqv_compressed_perm", il);
13559+ }
1353613560
1353713561 struct ggml_tensor * wv_b = ggml_view_3d(ctx0, model.layers[il].wv_b, kv_lora_rank, n_embd_head_v, n_head, ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank), ggml_row_size(model.layers[il].wv_b->type, kv_lora_rank * n_embd_head_v), 0);
1353813562 cb(wv_b, "wv_b", il);
0 commit comments