@@ -627,6 +627,7 @@ struct vk_flash_attn_push_constants {
627627 uint32_t nev2;
628628 uint32_t nev3;
629629 uint32_t nem1;
630+ uint32_t nem2;
630631
631632 uint32_t nb01;
632633 uint32_t nb02;
@@ -637,7 +638,6 @@ struct vk_flash_attn_push_constants {
637638 uint32_t nb21;
638639 uint32_t nb22;
639640 uint32_t nb23;
640- uint32_t nb31;
641641
642642 float scale;
643643 float max_bias;
@@ -652,6 +652,7 @@ struct vk_flash_attn_push_constants {
652652 uint32_t split_kv;
653653 uint32_t k_num;
654654};
655+ static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128");
655656
656657struct vk_op_push_constants {
657658 uint32_t KX;
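Why 128 in the new static_assert: the Vulkan spec only guarantees VkPhysicalDeviceLimits::maxPushConstantsSize >= 128 bytes, so any growth past that would break minimum-spec devices; trading nb31 for nem2 keeps the struct within the budget. A minimal compile-time sketch of the constraint, assuming (as the visible fields suggest) the struct is 32 four-byte scalars with no padding:

```cpp
#include <cstdint>

// Stand-in with the storage shape assumed for vk_flash_attn_push_constants:
// 32 four-byte scalar members (uint32_t/float), tightly packed.
struct push_constants_sketch {
    uint32_t fields[32];
};

// Mirrors the assert added above: 128 bytes is the portable minimum for
// VkPhysicalDeviceLimits::maxPushConstantsSize, so the struct must not exceed it.
static_assert(sizeof(push_constants_sketch) <= 128,
              "push constants must fit the guaranteed 128-byte limit");
```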
@@ -743,6 +744,14 @@ struct vk_op_rope_push_constants {
743744struct vk_op_soft_max_push_constants {
744745 uint32_t KX;
745746 uint32_t KY;
747+ uint32_t ne00;
748+ uint32_t ne01;
749+ uint32_t ne02;
750+ uint32_t ne12;
751+ uint32_t ne13;
752+ uint32_t nb11;
753+ uint32_t nb12;
754+ uint32_t nb13;
746755 float scale;
747756 float max_bias;
748757 float m0;
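These extra extents and element strides are what let the soft_max shader broadcast the mask (src1) across src0's head and batch dimensions instead of requiring matching shapes. A hedged host-side sketch of the addressing this enables; the helper is hypothetical, the modulo wrap is the usual ggml broadcast rule, and the strides are in elements, as packed in the ggml_vk_soft_max hunk further down:

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: element offset of the mask row paired with src0 row
// (i01, i02, i03). When ne12/ne13 are smaller than src0's extents, the
// per-dimension modulo repeats the mask, i.e. broadcasts it.
static size_t mask_row_offset(uint32_t i01, uint32_t i02, uint32_t i03,
                              uint32_t ne12, uint32_t ne13,
                              uint32_t nb11, uint32_t nb12, uint32_t nb13) {
    return (size_t) i01         * nb11
         + (size_t)(i02 % ne12) * nb12
         + (size_t)(i03 % ne13) * nb13;
}
```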
@@ -5977,7 +5986,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
59775986 GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
59785987
59795988 const uint32_t nem1 = mask ? mask->ne[1] : 0;
5980- const uint32_t nbm1 = mask ? mask->nb[1] : 0;
5989+ const uint32_t nem2 = mask ? mask->ne[2] : 0;
59815990
59825991 const uint32_t D = neq0;
59835992 uint32_t N = neq1;
@@ -6140,7 +6149,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61406149 // Try to use split_k when KV is large enough to be worth the overhead
61416150 if (workgroups_x == 1 && shader_core_count > 0 && KV >= 512) {
61426151 // Try to run two workgroups per SM.
6143- split_k = ctx->device->shader_core_count * 2 / workgroups_y;
6152+ split_k = ctx->device->shader_core_count * 2 / (workgroups_y * workgroups_z);
61446153 if (split_k > 1) {
61456154 // Try to evenly split KV into split_k chunks, but it needs to be a multiple
61466155 // of "align", so recompute split_k based on that.
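The new divisor accounts for the batch dimension that now occupies workgroups_z: the grid already spans workgroups_y * workgroups_z groups per KV split, so dividing by workgroups_y alone would oversubscribe the device whenever ne3 > 1. A worked sketch with illustrative numbers (core count and shapes are assumptions):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint32_t shader_core_count = 80; // e.g. SM count (assumption)
    const uint32_t workgroups_y      = 8;  // heads
    const uint32_t workgroups_z      = 2;  // batch; was implicitly 1 before

    // Target roughly two workgroups per shader core across the whole grid.
    // Dividing by workgroups_y alone would give 20 and oversubscribe the
    // GPU 2x once workgroups_z > 1; the fixed formula gives 10.
    const uint32_t split_k = shader_core_count * 2 / (workgroups_y * workgroups_z);
    printf("split_k = %u\n", split_k);
    return 0;
}
```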
@@ -6150,9 +6159,9 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
61506159 }
61516160 }
61526161
6153- // Reserve space for split_k temporaries. For each split, we need to store the O matrix (D x ne1)
6154- // and the per-row m and L values (ne1 rows).
6155- const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k : 0;
6162+ // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1)
6163+ // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows.
6164+ const uint64_t split_k_size = split_k > 1 ? (D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0;
61566165 if (split_k_size > ctx->device->max_memory_allocation_size) {
61576166 GGML_ABORT("Requested preallocation size is too large");
61586167 }
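Multiplying by ne3 follows from the layout described in the updated comment: each (split, batch) pair keeps its own partial O matrix plus per-row m and L values. A quick arithmetic check of the size formula with illustrative values (all numbers are assumptions):

```cpp
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t D = 128, ne1 = 512;    // head size, rows (assumptions)
    const uint64_t split_k = 10, ne3 = 4; // KV splits, batch

    // Per (split, batch): one D x ne1 O matrix plus m and L for each row.
    const uint64_t per_split = D * ne1 * sizeof(float) + ne1 * sizeof(float) * 2;
    const uint64_t total     = per_split * split_k * ne3; // new formula: scaled by ne3

    printf("split_k temporaries: %llu bytes\n", (unsigned long long) total);
    return 0;
}
```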
@@ -6244,11 +6253,10 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62446253 (uint32_t)neq2, (uint32_t)neq3,
62456254 (uint32_t)nek2, (uint32_t)nek3,
62466255 (uint32_t)nev2, (uint32_t)nev3,
6247- nem1,
6256+ nem1, nem2,
62486257 q_stride, (uint32_t)nbq2, (uint32_t)nbq3,
62496258 k_stride, (uint32_t)nbk2, (uint32_t)nbk3,
62506259 v_stride, (uint32_t)nbv2, (uint32_t)nbv3,
6251- nbm1,
62526260 scale, max_bias, logit_softcap,
62536261 mask != nullptr, n_head_log2, m0, m1,
62546262 gqa_ratio, split_kv, split_k };
@@ -6271,13 +6279,13 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
62716279 pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z });
62726280
62736281 ggml_vk_sync_buffers(subctx);
6274- const std::array<uint32_t, 3> pc2 = { D, (uint32_t)ne1, split_k };
6282+ const std::array<uint32_t, 4> pc2 = { D, (uint32_t)ne1, (uint32_t)ne3, split_k };
62756283 ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce,
62766284 {
62776285 vk_subbuffer{ctx->prealloc_split_k, 0, VK_WHOLE_SIZE},
62786286 vk_subbuffer{d_D, d_buf_offset, VK_WHOLE_SIZE},
62796287 },
6280- pc2, { (uint32_t)ne1, 1, 1 });
6288+ pc2, { (uint32_t)ne1, 1, (uint32_t)ne3 });
62816289 } else {
62826290 ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
62836291 {
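pc2 gains ne3 and the reduce pass is dispatched with ne3 workgroups in z, one reduction per batch. For orientation, this is the combine each output row needs, written as a CPU sketch rather than the shader source (the standard numerically stable split-k reduction; names are illustrative):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Combine split_k partial results for one row. O[i] is split i's partial
// output (D floats), m[i] its row max, L[i] its softmax denominator.
static std::vector<float> combine_splits(const std::vector<std::vector<float>> & O,
                                         const std::vector<float> & m,
                                         const std::vector<float> & L) {
    const size_t k = O.size(), D = O[0].size();
    float m_max = m[0];
    for (size_t i = 1; i < k; ++i) m_max = std::max(m_max, m[i]);

    std::vector<float> out(D, 0.0f);
    float L_sum = 0.0f;
    for (size_t i = 0; i < k; ++i) {
        const float w = std::exp(m[i] - m_max); // rescale split to the global max
        L_sum += L[i] * w;
        for (size_t d = 0; d < D; ++d) out[d] += O[i][d] * w;
    }
    for (size_t d = 0; d < D; ++d) out[d] /= L_sum; // final normalization
    return out;
}
```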
@@ -7562,7 +7570,13 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
75627570 const uint32_t nrows_x = (uint32_t)ggml_nrows(src0);
75637571 const uint32_t nrows_y = (uint32_t)src0->ne[1];
75647572
7565- const uint32_t n_head_kv = nrows_x/nrows_y;
7573+ const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u;
7574+ const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u;
7575+ const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u;
7576+ const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u;
7577+ const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u;
7578+
7579+ const uint32_t n_head_kv = src0->ne[2];
75667580 const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
75677581
75687582 const float m0 = powf(2.0f, -(max_bias) / n_head_log2);
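Reading n_head_kv from src0->ne[2] fixes the batched case: nrows_x / nrows_y equals ne02 * ne03 (ggml_nrows counts all higher dims), so the old quotient inflated the head count whenever ne[3] > 1 and skewed the ALiBi slopes. A worked example of the slope bases with the corrected count; the matching m1 line sits just outside the hunk context and follows the usual ggml convention of halving max_bias, and the values are illustrative:

```cpp
#include <cmath>
#include <cstdio>

int main() {
    const float    max_bias  = 8.0f; // illustrative
    const unsigned n_head_kv = 12;   // from src0->ne[2] after the fix

    // Same computation as the patched code path.
    const unsigned n_head_log2 = 1u << (unsigned) floorf(log2f((float) n_head_kv)); // -> 8
    const float m0 = powf(2.0f, -max_bias / n_head_log2);          // 2^(-1)   = 0.5
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); // 2^(-0.5) ~ 0.707

    printf("n_head_log2=%u m0=%f m1=%f\n", n_head_log2, m0, m1);
    return 0;
}
```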
@@ -7571,6 +7585,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
75717585 ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX, {
75727586 ncols,
75737587 src1 != nullptr ? nrows_y : (uint32_t)0,
7588+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],
7589+ ne12, ne13,
7590+ nb11, nb12, nb13,
75747591 scale, max_bias,
75757592 m0, m1,
75767593 n_head_log2,
@@ -10066,11 +10083,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1006610083 if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) {
1006710084 return false;
1006810085 }
10069- // TODO: support broadcast
10070- // ref: https://github.com/ggml-org/llama.cpp/pull/14435
10071- if (op->src[0]->ne[3] != 1) {
10072- return false;
10073- }
1007410086 // It's straightforward to support different K/V dequant, but would
1007510087 // significantly increase the number of pipelines
1007610088 if (op->src[1]->type != op->src[2]->type) {
@@ -10231,13 +10243,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
1023110243 case GGML_OP_DIAG_MASK_INF:
1023210244 return true;
1023310245 case GGML_OP_SOFT_MAX:
10234- // TODO: support batching
10235- if (op->src[0]->ne[3] != 1) {
10236- return false;
10237- }
10238- // TODO: support broadcast
10239- // ref: https://github.com/ggml-org/llama.cpp/pull/14435
10240- return !op->src[1] || (op->src[1]->ne[2] == 1 && op->src[1]->ne[3] == 1);
1024110246 case GGML_OP_SOFT_MAX_BACK:
1024210247 case GGML_OP_ARGSORT:
1024310248 case GGML_OP_SUM: