From fa52517f91d3124bcf19a3eb383f5068886eccc8 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Sun, 8 Mar 2026 14:04:03 -0400 Subject: [PATCH 1/5] metal : add Metal backend for GGML_OP_GATED_DELTA_NET Add a fused Metal kernel for the gated delta net recurrence op (#19504), enabling GPU-accelerated inference for DeltaNet-based models (Qwen3.5, etc.) on Apple Silicon. Supports both GDA (scalar gate) and KDA (per-row gate) modes with head_size 64 and 128. Unsupported configurations (head_size 32, non-contiguous tensors) gracefully fall back to CPU. Performance: Qwen3.5-0.8B Q4_K_M on M4 Max tg128: 170 -> 213 t/s (+25%) Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-metal/ggml-metal-device.cpp | 20 +++ ggml/src/ggml-metal/ggml-metal-device.h | 1 + ggml/src/ggml-metal/ggml-metal-device.m | 9 ++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 41 +++++++ ggml/src/ggml-metal/ggml-metal-ops.h | 1 + ggml/src/ggml-metal/ggml-metal.metal | 142 ++++++++++++++++++++++ 6 files changed, 214 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 169c63dd7a4..8079e165032 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -577,6 +577,26 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + // v is src[2], dimensions: S_v = ne[0], H = ne[1] + const int64_t S_v = op->src[2]->ne[0]; + const int64_t H = op->src[2]->ne[1]; + const int64_t C = op->ne[0]; + + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(C == S_v * H); + GGML_ASSERT(S_v == 64 || S_v == 128); + + const char * name = "kernel_gated_delta_net_f32"; + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr); + } + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 93d7f6a216f..fd2b3ddeb55 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 23bd2b2ab72..62ac46ab162 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1155,6 +1155,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + { + // Metal kernel supports head_size 64 and 128, contiguous tensors only + const int64_t S_v = op->src[2]->ne[0]; + return (S_v == 64 || S_v == 128) + && ggml_is_contiguous(op->src[0]) + && ggml_is_contiguous(op->src[1]) + && ggml_is_contiguous(op->src[2]); + } case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 267755d08cc..c913ec4f307 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_GATED_DELTA_NET: + { + n_fuse = ggml_metal_op_gated_delta_net(ctx, idx); + } break; case GGML_OP_SOLVE_TRI: { n_fuse = ggml_metal_op_solve_tri(ctx, idx); @@ -1562,6 +1566,43 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + // src[0]=q, src[1]=k, src[2]=v, src[3]=gate, src[4]=beta, src[5]=state + // Dimensions from v (src[2]): S_v=ne[0], H=ne[1], n_tokens=ne[2], n_seqs=ne[3] + const int64_t B = op->src[2]->ne[3]; // n_seqs + const int64_t T = op->src[2]->ne[2] * B; // total tokens + const int64_t C = op->ne[0]; // S_v * H + const int64_t H = op->src[2]->ne[1]; // num heads + const int64_t G = op->src[3]->ne[0]; // gate ne[0]: 1=GDA, S_v=KDA + + auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst + ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); + ggml_metal_encoder_set_bytes (enc, (void *) &G, sizeof(G), ida++); + + ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + + return 1; +} + int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f3e38c7aa9d..019f2fec9ed 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82ebbb4e409..75c97753ed1 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2434,6 +2434,148 @@ kernel void kernel_rwkv_wkv7_f32( } } +// Gated DeltaNet fused recurrence kernel (GDA and KDA gate modes) +// +// State layout: CPU kernel uses row-major M[i][j] at offset i*S+j. +// Thread tid owns column tid of M: state[j] = M[j][tid] (row j, col tid). +// This gives coalesced loads (consecutive threads read consecutive addresses). +// +// Gate semantics: +// GDA (G=1): scalar gate, M[i][j] *= exp(g) for all i,j +// KDA (G=S): per-row gate, M[i][j] *= exp(g[i]) for all j +// => thread tid must scale state[j] by exp(g[j]) since state[j] = M[j][tid] +// +// Grid: (B*H, 1, 1) threadgroups, (head_size, 1, 1) threads per threadgroup +// +// src layout (matches GGML_OP_GATED_DELTA_NET): +// src[0]=q, src[1]=k, src[2]=v, src[3]=gate, src[4]=beta, src[5]=state +// Dimensions from v: S_v=ne[0], H=ne[1], n_tokens=ne[2], n_seqs=ne[3] +// gate: [1,H,T,B] (GDA) or [S_v,H,T,B] (KDA) +kernel void kernel_gated_delta_net_f32( + device const float * q, + device const float * k, + device const float * v, + device const float * gate, // [G, H, T, B] log-space; G=1 (GDA) or G=S (KDA) + device const float * beta, // [1, H, T, B] + device const float * state_in, // [S*S*H, n_seqs] row-major per head + device float * dst, // [S*H, n_tokens*n_seqs + S*n_seqs] + constant uint & B, // n_seqs + constant uint & T, // n_tokens * n_seqs (total tokens) + constant uint & C, // S * H + constant uint & H, // num heads + constant uint & G, // gate ne[0]: 1=GDA, S=KDA + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + + const uint head_size = C / H; + const uint batch_id = tgpig.x / H; + const uint head_id = tgpig.x % H; + const uint tid = tpitg.x; + + if (batch_id >= B || head_id >= H || tid >= head_size) { + return; + } + + const float scale = 1.0f / sqrt((float)head_size); + + const uint state_size = C * head_size; // S * S * H + const uint n_seq_tokens = T / B; + + // Max head_size is 128 (enforced by pipeline getter). + threadgroup float _k[128]; + threadgroup float _q[128]; + threadgroup float _g[128]; // gate vector for KDA mode + + // Load initial state: thread tid owns column tid of M (row-major) + // M[row=j][col=tid] at offset j*head_size+tid (row-major, coalesced) + float state[128]; + for (uint j = 0; j < head_size; j++) { + state[j] = state_in[batch_id * state_size + head_id * head_size * head_size + + j * head_size + tid]; + } + + // Process tokens sequentially + for (uint tt = 0; tt < n_seq_tokens; tt++) { + const uint t_abs = batch_id * n_seq_tokens + tt; + const uint kv_offset = t_abs * C + head_id * head_size; + const uint gb_offset = t_abs * H + head_id; + + // Load k, q (and gate for KDA) into shared memory + threadgroup_barrier(mem_flags::mem_threadgroup); + _k[tid] = k[kv_offset + tid]; + _q[tid] = q[kv_offset + tid]; + if (G > 1) { + _g[tid] = exp(min(gate[t_abs * H * head_size + head_id * head_size + tid], 88.0f)); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float beta_val = beta[gb_offset]; + const float v_tid = v[kv_offset + tid]; + + // Decay state and compute sk = sum_j M[j][tid] * k[j] + // GDA: all elements scaled by same exp(g) + // KDA: state[j] = M[j][tid] scaled by exp(g[j]) (per-row gate) + float sk = 0.0f; + if (G == 1) { + const float g_exp = exp(min(gate[gb_offset], 88.0f)); + for (uint j = 0; j < head_size; j += 4) { + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + s_vec *= g_exp; + sk += dot(s_vec, k_vec); + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + } else { + for (uint j = 0; j < head_size; j += 4) { + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + float4 g_vec = float4(_g[j], _g[j+1], _g[j+2], _g[j+3]); + s_vec *= g_vec; + sk += dot(s_vec, k_vec); + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + } + + // Delta: d = (v[tid] - sk) * beta + // Note: delta is per-column (thread-local), matching CPU's delta[j] = (v[j] - sk[j]) * beta + // Here sk is sum_j M[j][tid]*k[j], and v_tid = v[tid], so d = delta for column tid + float d = (v_tid - sk) * beta_val; + + // State update: M[j][tid] += k[j] * d (rank-1 outer product k * delta^T) + for (uint j = 0; j < head_size; j += 4) { + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + s_vec += k_vec * d; + state[j] = s_vec[0]; + state[j+1] = s_vec[1]; + state[j+2] = s_vec[2]; + state[j+3] = s_vec[3]; + } + + // Output: o[tid] = sum_j M[j][tid] * q[j] * scale + float y = 0.0f; + for (uint j = 0; j < head_size; j += 4) { + float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); + float4 q_vec = float4(_q[j], _q[j+1], _q[j+2], _q[j+3]); + y += dot(s_vec, q_vec); + } + dst[t_abs * C + head_id * head_size + tid] = y * scale; + } + + // Write final state (row-major: M[j][tid] at j*head_size+tid) + for (uint j = 0; j < head_size; j++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + j * head_size + tid] = state[j]; + } +} + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; From e2d870b9c4f6a258a8eefca2239b939d675851ce Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Sun, 8 Mar 2026 14:08:13 -0400 Subject: [PATCH 2/5] metal : validate contiguity of all input tensors in supports_op Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-metal/ggml-metal-device.m | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 62ac46ab162..4b4fec2fc85 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1162,7 +1162,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return (S_v == 64 || S_v == 128) && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) - && ggml_is_contiguous(op->src[2]); + && ggml_is_contiguous(op->src[2]) + && ggml_is_contiguous(op->src[3]) + && ggml_is_contiguous(op->src[4]) + && ggml_is_contiguous(op->src[5]); } case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: From 000a1174aa8f42a18ed08ec8e833f9c415397423 Mon Sep 17 00:00:00 2001 From: Paul Flynn Date: Sun, 8 Mar 2026 14:17:42 -0400 Subject: [PATCH 3/5] metal : add algorithm equivalence comment for GDA decay path Co-Authored-By: Claude Opus 4.6 --- ggml/src/ggml-metal/ggml-metal.metal | 3 +++ 1 file changed, 3 insertions(+) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 75c97753ed1..4c020b3d066 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2514,6 +2514,9 @@ kernel void kernel_gated_delta_net_f32( const float v_tid = v[kv_offset + tid]; // Decay state and compute sk = sum_j M[j][tid] * k[j] + // Two-pass approach: decay state first, then compute on decayed state. + // Algebraically equivalent to CUDA's fused form: delta = (v - g*dot(S,k))*beta; + // S = g*S + k*delta, since dot(g*S, k) = g*dot(S, k). // GDA: all elements scaled by same exp(g) // KDA: state[j] = M[j][tid] scaled by exp(g[j]) (per-row gate) float sk = 0.0f; From 02225ea1aa95f2541b0e929ff4faad443563fd73 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 10 Mar 2026 15:48:16 +0200 Subject: [PATCH 4/5] cont : unslop + optimize --- ggml/src/ggml-cpu/ops.cpp | 3 +- ggml/src/ggml-metal/ggml-metal-device.cpp | 29 ++- ggml/src/ggml-metal/ggml-metal-device.m | 12 +- ggml/src/ggml-metal/ggml-metal-impl.h | 39 ++++ ggml/src/ggml-metal/ggml-metal-ops.cpp | 64 +++++-- ggml/src/ggml-metal/ggml-metal.metal | 210 ++++++++++------------ 6 files changed, 205 insertions(+), 152 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d324128c893..ffa272b8da7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); - const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); if (kda) { for (int64_t i = 0; i < S_v; ++i) { @@ -10501,7 +10501,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token } - } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 8079e165032..15ae2e517df 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -578,22 +578,37 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + // v is src[2], dimensions: S_v = ne[0], H = ne[1] - const int64_t S_v = op->src[2]->ne[0]; - const int64_t H = op->src[2]->ne[1]; - const int64_t C = op->ne[0]; + const int ne20 = op->src[2]->ne[0]; // S_v + const int ne21 = op->src[2]->ne[1]; // H + const int ne30 = op->src[3]->ne[0]; // G + + const int nsg = op->src[2]->ne[0]/32; GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); - GGML_ASSERT(C == S_v * H); - GGML_ASSERT(S_v == 64 || S_v == 128); + GGML_ASSERT(op->ne[0] == ne20 * ne21); + GGML_ASSERT(ne20 % 32 == 0); - const char * name = "kernel_gated_delta_net_f32"; + snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); + snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, name, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); + ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } + res.nsg = nsg; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 4b4fec2fc85..a4b176841ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1156,17 +1156,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV7: return true; case GGML_OP_GATED_DELTA_NET: - { - // Metal kernel supports head_size 64 and 128, contiguous tensors only - const int64_t S_v = op->src[2]->ne[0]; - return (S_v == 64 || S_v == 128) - && ggml_is_contiguous(op->src[0]) - && ggml_is_contiguous(op->src[1]) - && ggml_is_contiguous(op->src[2]) - && ggml_is_contiguous(op->src[3]) - && ggml_is_contiguous(op->src[4]) - && ggml_is_contiguous(op->src[5]); - } + return op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index bf51055e367..82dc2c728a1 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -84,6 +84,7 @@ #define FC_BIN 1300 #define FC_SUM_ROWS 1400 #define FC_UPSCALE 1500 +#define FC_GATED_DELTA_NET 1600 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -793,6 +794,44 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne22; + int32_t ne23; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ns02; + int32_t ns12; + int32_t ns22; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_gated_delta_net; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index c913ec4f307..306dbcf3660 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1572,19 +1572,60 @@ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; - // src[0]=q, src[1]=k, src[2]=v, src[3]=gate, src[4]=beta, src[5]=state - // Dimensions from v (src[2]): S_v=ne[0], H=ne[1], n_tokens=ne[2], n_seqs=ne[3] - const int64_t B = op->src[2]->ne[3]; // n_seqs - const int64_t T = op->src[2]->ne[2] * B; // total tokens - const int64_t C = op->ne[0]; // S_v * H - const int64_t H = op->src[2]->ne[1]; // num heads - const int64_t G = op->src[3]->ne[0]; // gate ne[0]: 1=GDA, S_v=KDA + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); int ida = 0; + ggml_metal_kargs_gated_delta_net args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, + /*.ne21 =*/ ne21, + /*.ne22 =*/ ne22, + /*.ne23 =*/ ne23, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ns02 =*/ (int32_t) (nb02/sizeof(float)), + /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), + /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v @@ -1592,13 +1633,10 @@ int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst - ggml_metal_encoder_set_bytes (enc, (void *) &B, sizeof(B), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &T, sizeof(T), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &C, sizeof(C), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &H, sizeof(H), ida++); - ggml_metal_encoder_set_bytes (enc, (void *) &G, sizeof(G), ida++); - ggml_metal_encoder_dispatch_threadgroups(enc, B * H, 1, 1, C/H, 1, 1); + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1); return 1; } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 4c020b3d066..987b230aedd 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2434,151 +2434,123 @@ kernel void kernel_rwkv_wkv7_f32( } } -// Gated DeltaNet fused recurrence kernel (GDA and KDA gate modes) -// -// State layout: CPU kernel uses row-major M[i][j] at offset i*S+j. -// Thread tid owns column tid of M: state[j] = M[j][tid] (row j, col tid). -// This gives coalesced loads (consecutive threads read consecutive addresses). -// -// Gate semantics: -// GDA (G=1): scalar gate, M[i][j] *= exp(g) for all i,j -// KDA (G=S): per-row gate, M[i][j] *= exp(g[i]) for all j -// => thread tid must scale state[j] by exp(g[j]) since state[j] = M[j][tid] -// -// Grid: (B*H, 1, 1) threadgroups, (head_size, 1, 1) threads per threadgroup -// -// src layout (matches GGML_OP_GATED_DELTA_NET): -// src[0]=q, src[1]=k, src[2]=v, src[3]=gate, src[4]=beta, src[5]=state -// Dimensions from v: S_v=ne[0], H=ne[1], n_tokens=ne[2], n_seqs=ne[3] -// gate: [1,H,T,B] (GDA) or [S_v,H,T,B] (KDA) -kernel void kernel_gated_delta_net_f32( - device const float * q, - device const float * k, - device const float * v, - device const float * gate, // [G, H, T, B] log-space; G=1 (GDA) or G=S (KDA) - device const float * beta, // [1, H, T, B] - device const float * state_in, // [S*S*H, n_seqs] row-major per head - device float * dst, // [S*H, n_tokens*n_seqs + S*n_seqs] - constant uint & B, // n_seqs - constant uint & T, // n_tokens * n_seqs (total tokens) - constant uint & C, // S * H - constant uint & H, // num heads - constant uint & G, // gate ne[0]: 1=GDA, S=KDA - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { +constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; +constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; - const uint head_size = C / H; - const uint batch_id = tgpig.x / H; - const uint head_id = tgpig.x % H; - const uint tid = tpitg.x; +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 - if (batch_id >= B || head_id >= H || tid >= head_size) { - return; - } + const uint tx = tpitg.x; + const uint ty = tpitg.y; - const float scale = 1.0f / sqrt((float)head_size); + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; - const uint state_size = C * head_size; // S * S * H - const uint n_seq_tokens = T / B; + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); - // Max head_size is 128 (enforced by pipeline getter). - threadgroup float _k[128]; - threadgroup float _q[128]; - threadgroup float _g[128]; // gate vector for KDA mode + device const float * s_ptr = (device const float *) (s) + (i23 * args.ne21*S_v*S_v + i21*S_v*S_v); - // Load initial state: thread tid owns column tid of M (row-major) - // M[row=j][col=tid] at offset j*head_size+tid (row-major, coalesced) - float state[128]; - for (uint j = 0; j < head_size; j++) { - state[j] = state_in[batch_id * state_size + head_id * head_size * head_size - + j * head_size + tid]; + float ls[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] = s_ptr[is*S_v + i20]; } - // Process tokens sequentially - for (uint tt = 0; tt < n_seq_tokens; tt++) { - const uint t_abs = batch_id * n_seq_tokens + tt; - const uint kv_offset = t_abs * C + head_id * head_size; - const uint gb_offset = t_abs * H + head_id; + device float * dst_attn = (device float *) (dst) + i23*args.ne22*args.ne21*S_v + i21*S_v + i20; - // Load k, q (and gate for KDA) into shared memory - threadgroup_barrier(mem_flags::mem_threadgroup); - _k[tid] = k[kv_offset + tid]; - _q[tid] = q[kv_offset + tid]; - if (G > 1) { - _g[tid] = exp(min(gate[t_abs * H * head_size + head_id * head_size + tid], 88.0f)); - } - threadgroup_barrier(mem_flags::mem_threadgroup); + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + FOR_UNROLL (short t = 0; t < args.ne22; t++) { + const float beta = b_ptr[0]; - const float beta_val = beta[gb_offset]; - const float v_tid = v[kv_offset + tid]; + float s_k = 0.0f; - // Decay state and compute sk = sum_j M[j][tid] * k[j] - // Two-pass approach: decay state first, then compute on decayed state. - // Algebraically equivalent to CUDA's fused form: delta = (v - g*dot(S,k))*beta; - // S = g*S + k*delta, since dot(g*S, k) = g*dot(S, k). - // GDA: all elements scaled by same exp(g) - // KDA: state[j] = M[j][tid] scaled by exp(g[j]) (per-row gate) - float sk = 0.0f; if (G == 1) { - const float g_exp = exp(min(gate[gb_offset], 88.0f)); - for (uint j = 0; j < head_size; j += 4) { - float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); - float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); - s_vec *= g_exp; - sk += dot(s_vec, k_vec); - state[j] = s_vec[0]; - state[j+1] = s_vec[1]; - state[j+2] = s_vec[2]; - state[j+3] = s_vec[3]; + const float g_exp = exp(g_ptr[0]); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= g_exp; + + s_k += ls[j]*k_ptr[is]; } } else { - for (uint j = 0; j < head_size; j += 4) { - float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); - float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); - float4 g_vec = float4(_g[j], _g[j+1], _g[j+2], _g[j+3]); - s_vec *= g_vec; - sk += dot(s_vec, k_vec); - state[j] = s_vec[0]; - state[j+1] = s_vec[1]; - state[j+2] = s_vec[2]; - state[j+3] = s_vec[3]; + // KDA + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= exp(g_ptr[is]); + + s_k += ls[j]*k_ptr[is]; } } - // Delta: d = (v[tid] - sk) * beta - // Note: delta is per-column (thread-local), matching CPU's delta[j] = (v[j] - sk[j]) * beta - // Here sk is sum_j M[j][tid]*k[j], and v_tid = v[tid], so d = delta for column tid - float d = (v_tid - sk) * beta_val; + s_k = simd_sum(s_k); - // State update: M[j][tid] += k[j] * d (rank-1 outer product k * delta^T) - for (uint j = 0; j < head_size; j += 4) { - float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); - float4 k_vec = float4(_k[j], _k[j+1], _k[j+2], _k[j+3]); - s_vec += k_vec * d; - state[j] = s_vec[0]; - state[j+1] = s_vec[1]; - state[j+2] = s_vec[2]; - state[j+3] = s_vec[3]; - } + const float d = (v_ptr[i20] - s_k)*beta; - // Output: o[tid] = sum_j M[j][tid] * q[j] * scale float y = 0.0f; - for (uint j = 0; j < head_size; j += 4) { - float4 s_vec = float4(state[j], state[j+1], state[j+2], state[j+3]); - float4 q_vec = float4(_q[j], _q[j+1], _q[j+2], _q[j+3]); - y += dot(s_vec, q_vec); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] += k_ptr[is]*d; + + y += ls[j]*q_ptr[is]; + } + + y = simd_sum(y); + + if (tx == 0) { + dst_attn[t*args.ne21*S_v] = y*scale; } - dst[t_abs * C + head_id * head_size + tid] = y * scale; + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; } - // Write final state (row-major: M[j][tid] at j*head_size+tid) - for (uint j = 0; j < head_size; j++) { - dst[T * C + batch_id * state_size + head_id * head_size * head_size - + j * head_size + tid] = state[j]; + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = ls[j]; } + +#undef S_v +#undef G } +typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]]; From 4ac4e0bd9a60f1ce8eadd56b8baccd5692b538ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 11 Mar 2026 12:42:55 +0200 Subject: [PATCH 5/5] cont : clean-up --- ggml/src/ggml-metal/ggml-metal.metal | 118 +++++++++++++++++++++++++-- 1 file changed, 111 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 987b230aedd..6d2f41909ae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2437,6 +2437,7 @@ kernel void kernel_rwkv_wkv7_f32( constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; +#if 1 template kernel void kernel_gated_delta_net_impl( constant ggml_metal_kargs_gated_delta_net & args, @@ -2465,16 +2466,16 @@ kernel void kernel_gated_delta_net_impl( const float scale = 1.0f / sqrt((float)S_v); - device const float * s_ptr = (device const float *) (s) + (i23 * args.ne21*S_v*S_v + i21*S_v*S_v); + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; float ls[NSG]; FOR_UNROLL (short j = 0; j < NSG; j++) { const short is = tx*NSG + j; - ls[j] = s_ptr[is*S_v + i20]; + ls[j] = s_ptr[is*S_v]; } - device float * dst_attn = (device float *) (dst) + i23*args.ne22*args.ne21*S_v + i21*S_v + i20; + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); @@ -2483,9 +2484,7 @@ kernel void kernel_gated_delta_net_impl( device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; - FOR_UNROLL (short t = 0; t < args.ne22; t++) { - const float beta = b_ptr[0]; - + for (short t = 0; t < args.ne22; t++) { float s_k = 0.0f; if (G == 1) { @@ -2509,7 +2508,7 @@ kernel void kernel_gated_delta_net_impl( s_k = simd_sum(s_k); - const float d = (v_ptr[i20] - s_k)*beta; + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; float y = 0.0f; @@ -2551,6 +2550,111 @@ template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; +#else +// a simplified version of the above +// no performance improvement, so keep the above version for now + +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float lsf[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + lsf[j] = s_ptr[is*S_v]; + } + + thread T * ls = (thread T *) (lsf); + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + device const T * qt_ptr = (device const T *) (q_ptr); + device const T * kt_ptr = (device const T *) (k_ptr); + device const T * gt_ptr = (device const T *) (g_ptr); + + if (G == 1) { + *ls *= exp(g_ptr[0]); + } else { + // KDA + *ls *= exp(gt_ptr[tx]); + } + + const float s_k = simd_sum(dot(*ls, kt_ptr[tx])); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + *ls += kt_ptr[tx]*d; + + const float y = simd_sum(dot(*ls, qt_ptr[tx])); + + if (tx == 0) { + *dst_attn = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + dst_attn += args.ne21*S_v; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device T * dstt_state = (device T *) (dst_state); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = lsf[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +#endif + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];