From 01dc936074e003d893b7734bbd9360f31542281d Mon Sep 17 00:00:00 2001 From: lvyichen Date: Fri, 15 May 2026 19:48:27 +0800 Subject: [PATCH] metal: reuse K/V in flash-attn vec for spec-decode --- ggml/src/ggml-metal/ggml-metal-device.cpp | 11 +- ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-impl.h | 4 + ggml/src/ggml-metal/ggml-metal-ops.cpp | 14 +- ggml/src/ggml-metal/ggml-metal.metal | 415 +++++++++++++--------- tests/test-backend-ops.cpp | 11 + 6 files changed, 286 insertions(+), 170 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index e288a27f992..02ab7c2810f 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -1437,6 +1437,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v bool has_bias, bool has_scap, bool has_kvpad, + int32_t nqpsg, int32_t nsg, int32_t nwg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -1450,11 +1451,17 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; - snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", + char nq_suffix[8] = {0}; + if (nqpsg > 1) { + snprintf(nq_suffix, sizeof(nq_suffix), "_q%d", nqpsg); + } + + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d%s", "flash_attn_ext_vec", ggml_type_name(op->src[1]->type), dk, - dv); + dv, + nq_suffix); snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", base, diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 1f212a92f98..77f5526a187 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -192,6 +192,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_att bool has_bias, bool has_scap, bool has_kvpad, + int32_t nqpsg, int32_t nsg, int32_t nwg); diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index ff74cafb5b7..cde391dd823 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -109,6 +109,10 @@ #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 +// minimum ne11 (KV length) for the dk=128 Q=2 vec specialization; +// below this the K/V reuse savings do not offset the extra register pressure +#define OP_FLASH_ATTN_EXT_VEC_Q2_DK128_MIN_KV 4096 + #define OP_UNARY_NUM_SCALE 10 #define OP_UNARY_NUM_FILL 11 #define OP_UNARY_NUM_CLAMP 12 diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index a114391c2e8..d869b295d93 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -10,6 +10,7 @@ #include #include +#include #include #include @@ -2871,7 +2872,14 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup + // Amortize one K/V tile read across two query rows. + static const bool disable_q2 = std::getenv("GGML_METAL_FA_DISABLE_Q2") != nullptr; + const bool can_q2 = !disable_q2 && + op->src[1]->type == GGML_TYPE_F16 && ne01 >= 2 && + ( (ne00 == 256 && ne20 == 256) || + (ne00 == 128 && ne20 == 128 && ne11 >= OP_FLASH_ATTN_EXT_VEC_Q2_DK128_MIN_KV) ); + + const int nqptg = can_q2 ? 2 : OP_FLASH_ATTN_EXT_VEC_NQPSG; // queries per threadgroup const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! const int nhptg = 1; // heads per threadgroup @@ -2935,7 +2943,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { // ne20*(nsg) // each simdgroup has a full f32 head vector in shared mem to accumulate results // -#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg))*(sizeof(float)/2), 16)) +#define FATTN_SMEM(nsg) (GGML_PAD(((GGML_PAD(ne00, 128) + 4*ncpsg + 2*GGML_PAD(ne20, 128))*(nsg)*nqptg)*(sizeof(float)/2), 16)) int64_t nsg = 1; @@ -2990,7 +2998,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); + auto pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nqptg, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82e29d5ad7c..800a5aa503c 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6733,25 +6733,25 @@ kernel void kernel_flash_attn_ext_vec( constexpr short NW = N_SIMDWIDTH; constexpr short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads - constexpr short SH = 4*C; // shared memory per simdgroup + constexpr short SH = 4*Q*C; // shared memory per simdgroup static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); //const short T = PK + NSG*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t - threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + NSG*PK); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + NSG*PK); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*C + NSG*PK); // scratch buffer for mask - threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*PV + NSG*PK + NSG*SH); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*NSG*PK); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*NSG*PK); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + 2*Q*C + Q*NSG*PK); // scratch buffer for mask + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 2*sgitg*Q*PV + Q*NSG*PK + NSG*SH); // scratch buffer for the results // store the result for all queries in shared memory (the O matrix from the paper) so4 += tiisg; { - q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; + q += iq1*Q*args.nb01 + iq2*args.nb02 + iq3*args.nb03; const short ikv2 = iq2/(args.ne02/args.ne_12_2); const short ikv3 = iq3/(args.ne03/args.ne_12_3); @@ -6760,22 +6760,32 @@ kernel void kernel_flash_attn_ext_vec( v += ikv2*args.nb22 + ikv3*args.nb23; } - // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q); - - if (iq1 < args.ne01) { - for (short i = tiisg; i < PK4; i += NW) { - if (i < DK4) { - sq4[i] = (q4_t) q4[i]; + // load Q query rows to shared memory + { + for (short qq = 0; qq < Q; ++qq) { + const int iq1_q = iq1*Q + qq; + device const float4 * q4 = (device const float4 *) ((device const char *) q + qq*args.nb01); + if (iq1_q < args.ne01) { + for (short i = tiisg; i < PK4; i += NW) { + if (i < DK4) { + sq4[qq*PK4 + i] = (q4_t) q4[i]; + } else { + sq4[qq*PK4 + i] = (q4_t) 0.0f; + } + } } else { - sq4[i] = (q4_t) 0.0f; + for (short i = tiisg; i < PK4; i += NW) { + sq4[qq*PK4 + i] = (q4_t) 0.0f; + } } } } // zero out so - for (short i = 0; i < DV4/NL; ++i) { - so4[i*NL] = (o4_t) 0.0f; + for (short qq = 0; qq < Q; ++qq) { + for (short i = 0; i < DV4/NL; ++i) { + so4[qq*DV4 + i*NL] = (o4_t) 0.0f; + } } // zero out shared memory SH @@ -6786,15 +6796,19 @@ kernel void kernel_flash_attn_ext_vec( threadgroup_barrier(mem_flags::mem_threadgroup); { - float S = 0.0f; - float M = -FLT_MAX/2; + float S[Q]; + float M[Q]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + S[qq] = 0.0f; + M[qq] = -FLT_MAX/2; + } // thread indices inside the simdgroup const short tx = tiisg%NL; const short ty = tiisg/NL; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); + device const half * pm_base = (device const half *) (mask + iq1*Q*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); float slope = 1.0f; @@ -6816,6 +6830,13 @@ kernel void kernel_flash_attn_ext_vec( break; } + device const half * pm[Q]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + // padded query rows clamp to row 0 of the mask to avoid OOB; their scores + // are forced to -inf below, so the values never affect the result. + pm[qq] = pm_base + ((iq1*Q + qq) < args.ne01 ? qq*(args.nb31/sizeof(half)) : -iq1*Q*(args.nb31/sizeof(half))); + } + // the last partial chunk uses the pad buffer as source if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { k = pad; @@ -6829,43 +6850,72 @@ kernel void kernel_flash_attn_ext_vec( v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; if (!FC_flash_attn_ext_vec_has_mask) { - if (ic + tiisg >= args.ne11) { - sm[tiisg] = -MAXHALF; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if (ic + tiisg >= args.ne11) { + sm[qq*C + tiisg] = -MAXHALF; + } } } else { - pm = (device const half *) (mask) + - iq1*C + - (iq2%args.ne32)*(C*args.ne31) + - (iq3%args.ne33)*(C*args.ne31*args.ne32); + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + pm[qq] = (device const half *) (mask) + + (iq1*Q + qq)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } } ic = 0; } if (FC_flash_attn_ext_vec_has_mask) { - sm[tiisg] = pm[ic + tiisg]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if ((iq1*Q + qq) < args.ne01) { + sm[qq*C + tiisg] = pm[qq][ic + tiisg]; + } else { + sm[qq*C + tiisg] = -MAXHALF; + } + } + } else { + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if ((iq1*Q + qq) >= args.ne01) { + sm[qq*C + tiisg] = -MAXHALF; + } + } } - // skip -INF blocks - if (simd_max(sm[tiisg]) <= -MAXHALF) { - continue; + { + bool any_finite = false; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if (simd_max(sm[qq*C + tiisg]) > -MAXHALF) { + any_finite = true; + } + } + if (!any_finite) { + continue; + } } // Q*K^T { device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); - threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; - pq4 += tx; - qk_t mqk[C/NE] = { [ 0 ... C/NE - 1] = 0.0f }; + qk_t mqk[Q][C/NE]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { + mqk[qq][cc] = 0.0f; + } + } - // each simdgroup processes 1 query and NE (NW/NL) cache elements + // each simdgroup processes Q queries and NE (NW/NL) cache elements FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { if (is_same::value) { FOR_UNROLL (short ii = 0; ii < DK4/NL; ++ii) { - mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); + const k4_t k_elem = pk4[cc*NE*NS10/4 + ii*NL]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + mqk[qq][cc] += dot((float4) k_elem, (float4) sq4[qq*PK4 + ii*NL + tx]); + } } } else { device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); @@ -6877,57 +6927,63 @@ kernel void kernel_flash_attn_ext_vec( deq_k_t4(pk + i/nl_k, i%nl_k, mk); - mqk[cc] += dot((float4) mk, (float4) sq4[i]); + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + mqk[qq][cc] += dot((float4) mk, (float4) sq4[qq*PK4 + i]); + } } } - if (NE == 1) { - mqk[cc] = simd_sum(mqk[cc]); - } else { - // simdgroup reduce (NE = 4) - // [ 0 .. 7] -> [ 0] - // [ 8 .. 15] -> [ 8] - // [16 .. 23] -> [16] - // [24 .. 31] -> [24] - if (NE <= 1) { - mqk[cc] += simd_shuffle_down(mqk[cc], 16); - } - if (NE <= 2) { - mqk[cc] += simd_shuffle_down(mqk[cc], 8); - } - if (NE <= 4) { - mqk[cc] += simd_shuffle_down(mqk[cc], 4); - } - if (NE <= 8) { - mqk[cc] += simd_shuffle_down(mqk[cc], 2); - } - if (NE <= 16) { - mqk[cc] += simd_shuffle_down(mqk[cc], 1); - } + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if (NE == 1) { + mqk[qq][cc] = simd_sum(mqk[qq][cc]); + } else { + // simdgroup reduce (NE = 4) + // [ 0 .. 7] -> [ 0] + // [ 8 .. 15] -> [ 8] + // [16 .. 23] -> [16] + // [24 .. 31] -> [24] + if (NE <= 1) { + mqk[qq][cc] += simd_shuffle_down(mqk[qq][cc], 16); + } + if (NE <= 2) { + mqk[qq][cc] += simd_shuffle_down(mqk[qq][cc], 8); + } + if (NE <= 4) { + mqk[qq][cc] += simd_shuffle_down(mqk[qq][cc], 4); + } + if (NE <= 8) { + mqk[qq][cc] += simd_shuffle_down(mqk[qq][cc], 2); + } + if (NE <= 16) { + mqk[qq][cc] += simd_shuffle_down(mqk[qq][cc], 1); + } - // broadcast - mqk[cc] = simd_shuffle(mqk[cc], NL*ty); + // broadcast + mqk[qq][cc] = simd_shuffle(mqk[qq][cc], NL*ty); + } } } - if (FC_flash_attn_ext_vec_has_mask && - !FC_flash_attn_ext_vec_has_scap && - !FC_flash_attn_ext_vec_has_bias) { - ss[NE*tx + ty] = fma(mqk[tx], args.scale, (qk_t) sm[NE*tx + ty]); - } else { - mqk[tx] *= args.scale; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + if (FC_flash_attn_ext_vec_has_mask && + !FC_flash_attn_ext_vec_has_scap && + !FC_flash_attn_ext_vec_has_bias) { + ss[qq*C + NE*tx + ty] = fma(mqk[qq][tx], args.scale, (qk_t) sm[qq*C + NE*tx + ty]); + } else { + mqk[qq][tx] *= args.scale; - if (FC_flash_attn_ext_vec_has_scap) { - mqk[tx] = args.logit_softcap*precise::tanh(mqk[tx]); - } + if (FC_flash_attn_ext_vec_has_scap) { + mqk[qq][tx] = args.logit_softcap*precise::tanh(mqk[qq][tx]); + } - if (FC_flash_attn_ext_vec_has_bias) { - mqk[tx] += (qk_t) sm[NE*tx + ty]*slope; - } else { - mqk[tx] += (qk_t) sm[NE*tx + ty]; - } + if (FC_flash_attn_ext_vec_has_bias) { + mqk[qq][tx] += (qk_t) sm[qq*C + NE*tx + ty]*slope; + } else { + mqk[qq][tx] += (qk_t) sm[qq*C + NE*tx + ty]; + } - ss[NE*tx + ty] = mqk[tx]; + ss[qq*C + NE*tx + ty] = mqk[qq][tx]; + } } } @@ -6935,23 +6991,25 @@ kernel void kernel_flash_attn_ext_vec( // online softmax { - const float m = M; - const float s = ss[tiisg]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + const float m = M[qq]; + const float s = ss[qq*C + tiisg]; - M = simd_max(max(M, s)); + M[qq] = simd_max(max(M[qq], s)); - const float ms = exp(m - M); - const float vs = exp(s - M); + const float ms = exp(m - M[qq]); + const float vs = exp(s - M[qq]); - S = S*ms + simd_sum(vs); + S[qq] = S[qq]*ms + simd_sum(vs); - // the P matrix from the paper (Q rows, C columns) - ss[tiisg] = vs; + // the P matrix from the paper (Q rows, C columns) + ss[qq*C + tiisg] = vs; - // O = diag(ms)*O - if ((DV4/NL % NW == 0) || ty == 0) { - FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - so4[ii*NL] *= ms; + // O = diag(ms)*O + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[qq*DV4 + ii*NL] *= ms; + } } } } @@ -6960,9 +7018,11 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { - o4_t lo[DV4/NL]; - FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - lo[ii] = 0.0f; + o4_t lo[Q][DV4/NL]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + lo[qq][ii] = 0.0f; + } } if (is_same::value) { @@ -6970,11 +7030,12 @@ kernel void kernel_flash_attn_ext_vec( pv4 += ty*NS20/4 + tx; - const auto sst = ss + ty; - FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - lo[ii] += o4_t(float4(pv4[cc*NE*NS20/4 + ii*NL])*float4(sst[cc*NE])); + const v4_t v_elem = pv4[cc*NE*NS20/4 + ii*NL]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + lo[qq][ii] += o4_t(float4(v_elem)*float4(ss[qq*C + cc*NE + ty])); + } } } } else { @@ -6987,78 +7048,88 @@ kernel void kernel_flash_attn_ext_vec( v4_t mv; deq_v_t4(pv4 + i/nl_v, i%nl_v, mv); - lo[ii] += o4_t(float4(mv)*float4(ss[NE*cc + ty])); + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + lo[qq][ii] += o4_t(float4(mv)*float4(ss[qq*C + NE*cc + ty])); + } } } } - FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - if (NE > 1) { - lo[ii][0] += simd_shuffle_down(lo[ii][0], 16); - lo[ii][1] += simd_shuffle_down(lo[ii][1], 16); - lo[ii][2] += simd_shuffle_down(lo[ii][2], 16); - lo[ii][3] += simd_shuffle_down(lo[ii][3], 16); - } + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + if (NE > 1) { + lo[qq][ii][0] += simd_shuffle_down(lo[qq][ii][0], 16); + lo[qq][ii][1] += simd_shuffle_down(lo[qq][ii][1], 16); + lo[qq][ii][2] += simd_shuffle_down(lo[qq][ii][2], 16); + lo[qq][ii][3] += simd_shuffle_down(lo[qq][ii][3], 16); + } - if (NE > 2) { - lo[ii][0] += simd_shuffle_down(lo[ii][0], 8); - lo[ii][1] += simd_shuffle_down(lo[ii][1], 8); - lo[ii][2] += simd_shuffle_down(lo[ii][2], 8); - lo[ii][3] += simd_shuffle_down(lo[ii][3], 8); - } + if (NE > 2) { + lo[qq][ii][0] += simd_shuffle_down(lo[qq][ii][0], 8); + lo[qq][ii][1] += simd_shuffle_down(lo[qq][ii][1], 8); + lo[qq][ii][2] += simd_shuffle_down(lo[qq][ii][2], 8); + lo[qq][ii][3] += simd_shuffle_down(lo[qq][ii][3], 8); + } - if (NE > 4) { - lo[ii][0] += simd_shuffle_down(lo[ii][0], 4); - lo[ii][1] += simd_shuffle_down(lo[ii][1], 4); - lo[ii][2] += simd_shuffle_down(lo[ii][2], 4); - lo[ii][3] += simd_shuffle_down(lo[ii][3], 4); - } + if (NE > 4) { + lo[qq][ii][0] += simd_shuffle_down(lo[qq][ii][0], 4); + lo[qq][ii][1] += simd_shuffle_down(lo[qq][ii][1], 4); + lo[qq][ii][2] += simd_shuffle_down(lo[qq][ii][2], 4); + lo[qq][ii][3] += simd_shuffle_down(lo[qq][ii][3], 4); + } - if (NE > 8) { - lo[ii][0] += simd_shuffle_down(lo[ii][0], 2); - lo[ii][1] += simd_shuffle_down(lo[ii][1], 2); - lo[ii][2] += simd_shuffle_down(lo[ii][2], 2); - lo[ii][3] += simd_shuffle_down(lo[ii][3], 2); - } + if (NE > 8) { + lo[qq][ii][0] += simd_shuffle_down(lo[qq][ii][0], 2); + lo[qq][ii][1] += simd_shuffle_down(lo[qq][ii][1], 2); + lo[qq][ii][2] += simd_shuffle_down(lo[qq][ii][2], 2); + lo[qq][ii][3] += simd_shuffle_down(lo[qq][ii][3], 2); + } - if (NE > 16) { - lo[ii][0] += simd_shuffle_down(lo[ii][0], 1); - lo[ii][1] += simd_shuffle_down(lo[ii][1], 1); - lo[ii][2] += simd_shuffle_down(lo[ii][2], 1); - lo[ii][3] += simd_shuffle_down(lo[ii][3], 1); + if (NE > 16) { + lo[qq][ii][0] += simd_shuffle_down(lo[qq][ii][0], 1); + lo[qq][ii][1] += simd_shuffle_down(lo[qq][ii][1], 1); + lo[qq][ii][2] += simd_shuffle_down(lo[qq][ii][2], 1); + lo[qq][ii][3] += simd_shuffle_down(lo[qq][ii][3], 1); + } } } if ((DV4/NL % NW == 0) || ty == 0) { - FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - so4[ii*NL] += lo[ii]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[qq*DV4 + ii*NL] += lo[qq][ii]; + } } } } } if (FC_flash_attn_ext_vec_has_sinks && sgitg == 0 && iwg == 0) { - const float m = M; - const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + const float m = M[qq]; + const float s = tiisg == 0 ? ((device const float *) sinks)[iq2] : -FLT_MAX/2; - M = simd_max(max(M, s)); + M[qq] = simd_max(max(M[qq], s)); - const float ms = exp(m - M); - const float vs = exp(s - M); + const float ms = exp(m - M[qq]); + const float vs = exp(s - M[qq]); - S = S*ms + simd_sum(vs); + S[qq] = S[qq]*ms + simd_sum(vs); - if ((DV4/NL % NW == 0) || ty == 0) { - FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { - so4[ii*NL] *= ms; + if ((DV4/NL % NW == 0) || ty == 0) { + FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { + so4[qq*DV4 + ii*NL] *= ms; + } } } } // these are needed for reducing the results from the simdgroups (reuse the ss buffer) if (tiisg == 0) { - ss[0] = (s_t) S; - ss[1] = (s_t) M; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + ss[2*qq + 0] = (s_t) S[qq]; + ss[2*qq + 1] = (s_t) M[qq]; + } } } @@ -7069,27 +7140,29 @@ kernel void kernel_flash_attn_ext_vec( // parallel reduce for (short r = NSG/2; r > 0; r >>= 1) { if (sgitg < r) { - const float S0 = ss[ 0]; - const float S1 = ss[r*(SH/2) + 0]; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + const float S0 = ss[ 2*qq + 0]; + const float S1 = ss[r*(SH/2) + 2*qq + 0]; - const float M0 = ss[ 1]; - const float M1 = ss[r*(SH/2) + 1]; + const float M0 = ss[ 2*qq + 1]; + const float M1 = ss[r*(SH/2) + 2*qq + 1]; - const float M = max(M0, M1); + const float Mx = max(M0, M1); - const float ms0 = exp(M0 - M); - const float ms1 = exp(M1 - M); + const float ms0 = exp(M0 - Mx); + const float ms1 = exp(M1 - Mx); - const float S = S0*ms0 + S1*ms1; + const float Sx = S0*ms0 + S1*ms1; - if (tiisg == 0) { - ss[0] = S; - ss[1] = M; - } + if (tiisg == 0) { + ss[2*qq + 0] = Sx; + ss[2*qq + 1] = Mx; + } - // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 - for (short i = tiisg; i < DV4; i += NW) { - so4[i] = so4[i]*ms0 + so4[i + r*PV4]*ms1; + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (short i = tiisg; i < DV4; i += NW) { + so4[qq*DV4 + i] = so4[qq*DV4 + i]*ms0 + so4[qq*DV4 + i + r*Q*PV4]*ms1; + } } } @@ -7099,23 +7172,31 @@ kernel void kernel_flash_attn_ext_vec( // final rescale with 1/S and store to global memory if (sgitg == 0) { const int64_t nrows = args.ne3*args.ne2*args.ne1; - const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1*args.ne1; device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; + FOR_UNROLL (short qq = 0; qq < Q; ++qq) { + const int iq1_q = iq1*Q + qq; + if (iq1_q >= args.ne01) { + continue; + } - // interleave the workgroup data - for (short i = tiisg; i < DV4; i += NW) { - dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[i]*S; - } + const int64_t rid = iq3*args.ne2*args.ne1 + iq2 + iq1_q*args.ne1; - // store S and M - if (NWG > 1) { - if (tiisg == 0) { - dst1[rid*(2*NWG) + 2*iwg + 0] = ss[0]; - dst1[rid*(2*NWG) + 2*iwg + 1] = ss[1]; + const float Sval = NWG == 1 ? (ss[2*qq + 0] == 0.0f ? 0.0f : 1.0f/ss[2*qq + 0]) : 1.0f; + + // interleave the workgroup data + for (short i = tiisg; i < DV4; i += NW) { + dst4[rid*DV4*NWG + NWG*i + iwg] = (float4) so4[qq*DV4 + i]*Sval; + } + + // store S and M + if (NWG > 1) { + if (tiisg == 0) { + dst1[rid*(2*NWG) + 2*iwg + 0] = ss[2*qq + 0]; + dst1[rid*(2*NWG) + 2*iwg + 1] = ss[2*qq + 1]; + } } } } @@ -7191,6 +7272,8 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk128_dv128")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk128_dv128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk128_dv128_q2")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk192_dv192")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) @@ -7224,6 +7307,8 @@ template [[host_name("kernel_flash_attn_ext_vec_q5_0_dk256_dv256")]] kernel flas template [[host_name("kernel_flash_attn_ext_vec_q5_1_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_q8_0_dk256_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; +template [[host_name("kernel_flash_attn_ext_vec_f16_dk256_dv256_q2")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; + template [[host_name("kernel_flash_attn_ext_vec_f32_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; template [[host_name("kernel_flash_attn_ext_vec_f16_dk320_dv256")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec; #if defined(GGML_METAL_HAS_BF16) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 76f7cb5a867..f37d373e13a 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -9348,6 +9348,17 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 4, 128, 1024, 1)); // 4h PP-1024 test_cases.emplace_back(new test_gated_delta_net(GGML_TYPE_F32, 32, 128, 64, 1, 1, false, true)); // KDA PP-64 + for (int hsk : { 128, 256 }) { + for (int kv : { 4096, 4097, 8192, 16384 }) { + for (int nb : { 1, 2, 3, 5, 8 }) { + test_cases.emplace_back(new test_flash_attn_ext( + hsk, hsk, 4, {hsk == 128 ? 8 : 6, 1}, kv, nb, + true, false, 0.0f, 0.0f, + GGML_PREC_F32, GGML_TYPE_F16, GGML_TYPE_F16)); + } + } + } + return test_cases; }