diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 61d112c50a7..9541a815693 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -813,6 +813,12 @@ struct vk_device_struct { vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_rwkv_wkv6_f32; vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_gated_delta_net_f32_d32; + vk_pipeline pipeline_gated_delta_net_f32_d64; + vk_pipeline pipeline_gated_delta_net_f32_d128; + vk_pipeline pipeline_gated_delta_net_f32_d32_kda; + vk_pipeline pipeline_gated_delta_net_f32_d64_kda; + vk_pipeline pipeline_gated_delta_net_f32_d128_kda; vk_pipeline pipeline_ssm_scan_f32_d128; vk_pipeline pipeline_ssm_scan_f32_d256; vk_pipeline pipeline_ssm_conv_f32; @@ -1439,6 +1445,17 @@ struct vk_op_rwkv_wkv7_push_constants { uint32_t C; uint32_t H; }; +struct vk_op_gated_delta_net_push_constants { + uint32_t H; + uint32_t n_tokens; + uint32_t n_seqs; + uint32_t s_off; + uint32_t sq1, sq2, sq3; + uint32_t sv1, sv2, sv3; + uint32_t sb1, sb2, sb3; + uint32_t rq1, rq3; +}; + struct vk_op_ssm_scan_push_constants { uint32_t nb02, nb03, nb12, nb13; uint32_t nb21, nb22, nb31; @@ -4559,6 +4576,13 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d32, "gated_delta_net_f32_d32", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {32, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d64, "gated_delta_net_f32_d64", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {64, 0}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_gated_delta_net_f32_d128, "gated_delta_net_f32_d128", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {128, 0}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d32_kda, "gated_delta_net_f32_d32_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {32, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d64_kda, "gated_delta_net_f32_d64_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {64, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d128_kda, "gated_delta_net_f32_d128_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {128, 1}, 1); + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true); ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true); @@ -9478,6 +9502,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_rwkv_wkv7_f32; } return nullptr; + case GGML_OP_GATED_DELTA_NET: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t S_v = dst->src[2]->ne[0]; + const bool kda = (dst->src[3]->ne[0] == (int64_t)S_v); + if (kda) { + switch (S_v) { + case 32: return ctx->device->pipeline_gated_delta_net_f32_d32_kda; + case 64: return 
ctx->device->pipeline_gated_delta_net_f32_d64_kda; + case 128: return ctx->device->pipeline_gated_delta_net_f32_d128_kda; + } + } else { + switch (S_v) { + case 32: return ctx->device->pipeline_gated_delta_net_f32_d32; + case 64: return ctx->device->pipeline_gated_delta_net_f32_d64; + case 128: return ctx->device->pipeline_gated_delta_net_f32_d128; + } + } + } + return nullptr; case GGML_OP_SSM_SCAN: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { const uint32_t d_state = src0->ne[0]; @@ -10308,6 +10351,67 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ); } +static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { + const ggml_tensor * src_q = dst->src[0]; + const ggml_tensor * src_k = dst->src[1]; + const ggml_tensor * src_v = dst->src[2]; + const ggml_tensor * src_g = dst->src[3]; + const ggml_tensor * src_beta = dst->src[4]; + const ggml_tensor * src_state = dst->src[5]; + + GGML_ASSERT(dst->buffer != nullptr); + + const uint32_t S_v = (uint32_t)src_v->ne[0]; + const uint32_t H = (uint32_t)src_v->ne[1]; + const uint32_t n_tokens = (uint32_t)src_v->ne[2]; + const uint32_t n_seqs = (uint32_t)src_v->ne[3]; + + const uint32_t s_off = S_v * H * n_tokens * n_seqs; + + for (int i = 0; i < 6; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer src_buf[6] = {}; + for (int i = 0; i < 6; i++) { + src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]); + } + + // q and k share strides + const uint32_t sq1 = (uint32_t)(src_q->nb[1] / sizeof(float)); + const uint32_t sq2 = (uint32_t)(src_q->nb[2] / sizeof(float)); + const uint32_t sq3 = (uint32_t)(src_q->nb[3] / sizeof(float)); + 
const uint32_t sv1 = (uint32_t)(src_v->nb[1] / sizeof(float)); + const uint32_t sv2 = (uint32_t)(src_v->nb[2] / sizeof(float)); + const uint32_t sv3 = (uint32_t)(src_v->nb[3] / sizeof(float)); + const uint32_t sb1 = (uint32_t)(src_beta->nb[1] / sizeof(float)); + const uint32_t sb2 = (uint32_t)(src_beta->nb[2] / sizeof(float)); + const uint32_t sb3 = (uint32_t)(src_beta->nb[3] / sizeof(float)); + + const uint32_t rq1 = (uint32_t)(src_v->ne[1] / src_q->ne[1]); + const uint32_t rq3 = (uint32_t)(src_v->ne[3] / src_q->ne[3]); + + const vk_op_gated_delta_net_push_constants pc = { + H, n_tokens, n_seqs, s_off, + sq1, sq2, sq3, + sv1, sv2, sv3, + sb1, sb2, sb3, + rq1, rq3 + }; + + std::array<uint32_t, 3> elements = { H, n_seqs, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf}, + pc, elements); +} + static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -13024,6 +13128,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; + case GGML_OP_GATED_DELTA_NET: + ggml_vk_gated_delta_net(ctx, compute_ctx, node); + + break; + case GGML_OP_SSM_SCAN: ggml_vk_ssm_scan(ctx, compute_ctx, node); @@ -15426,6 +15535,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; // all inputs are contiguous, see ggml.c + case GGML_OP_GATED_DELTA_NET: + { + const uint32_t S_v = op->src[2]->ne[0]; + if (S_v != 32 && S_v != 64 && S_v != 128) { + return false; + } + for (int i = 0; i < 6; i++) { + if (op->src[i] && ggml_is_quantized(op->src[i]->type)) { + return false; + } + } + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + } case GGML_OP_SSM_SCAN: { for (int i = 0; i < 6; i++) { @@ -16299,6 +16421,9 @@ static void
ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * } else if (tensor->op == GGML_OP_RWKV_WKV7) { tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_GATED_DELTA_NET) { + tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { src_clone[0]->flags = tensor->src[0]->flags; tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp new file mode 100644 index 00000000000..61676f65576 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp @@ -0,0 +1,122 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +layout(constant_id = 0) const uint S_V = 128; +layout(constant_id = 1) const uint KDA = 0; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint H; + uint n_tokens; + uint n_seqs; + uint s_off; + uint sq1, sq2, sq3; + uint sv1, sv2, sv3; + uint sb1, sb2, sb3; + uint rq1, rq3; +}; + +layout(binding = 0) readonly buffer QBuf { float q_in[]; }; +layout(binding = 1) readonly buffer KBuf { float k_in[]; }; +layout(binding = 2) readonly buffer VBuf { float v_in[]; }; +layout(binding = 3) readonly buffer GBuf { float g_in[]; }; +layout(binding = 4) readonly buffer BetaBuf { float beta_in[]; }; +layout(binding = 5) readonly buffer StateBuf { float state_in[]; }; +layout(binding = 6) buffer DstBuf { float dst[]; }; + +shared float s_k[S_V]; +shared float s_q[S_V]; + +void main() { + const uint head_id = gl_WorkGroupID.x; + const uint seq_id = gl_WorkGroupID.y; + const uint col = gl_LocalInvocationID.x; + + if (col >= S_V) { + return; + } + + const uint iq1 = 
head_id / rq1; + const uint iq3 = seq_id / rq3; + + const uint state_size = S_V * S_V; + const uint state_base = (seq_id * H + head_id) * state_size; + + const float scale = 1.0 / sqrt(float(S_V)); + + // Load state column into registers: S[i][col] for all rows i + float state[S_V]; + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = state_in[state_base + i * S_V + col]; + } + + uint attn_off = (seq_id * n_tokens * H + head_id) * S_V; + + for (uint t = 0; t < n_tokens; t++) { + // Load q and k into shared memory + const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1; + const uint k_off = q_off; // q and k share strides (asserted in ggml.c) + const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1; + + s_q[col] = q_in[q_off + col]; + s_k[col] = k_in[k_off + col]; + barrier(); + + const float v_val = v_in[v_off + col]; + + // beta and gate offsets + const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1; + const float beta_val = beta_in[gb_off]; + + if (KDA == 0) { + const float g_val = exp(g_in[gb_off]); + + // kv_col = (S^T @ k)[col] = sum_i S[i][col] * k[i] + float kv_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i++) { + kv_col += state[i] * s_k[i]; + } + + // delta_col = (v[col] - g * kv_col) * beta + float delta_col = (v_val - g_val * kv_col) * beta_val; + + // Fused decay + rank-1 update + readout + float attn_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = g_val * state[i] + s_k[i] * delta_col; + attn_col += state[i] * s_q[i]; + } + + dst[attn_off + col] = attn_col * scale; + } else { + // KDA: per-row vector gate + const uint g_base = gb_off * S_V; + + // kv_col = sum_i exp(g[i]) * S[i][col] * k[i] + float kv_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i++) { + kv_col += exp(g_in[g_base + i]) * state[i] * s_k[i]; + } + + float delta_col = (v_val - kv_col) * beta_val; + + float attn_col = 0.0; + [[unroll]] for (uint i = 0; i < S_V; i++) { + state[i] = exp(g_in[g_base + i]) * state[i] + s_k[i] * delta_col; + attn_col 
+= state[i] * s_q[i]; + } + + dst[attn_off + col] = attn_col * scale; + } + + attn_off += S_V * H; + barrier(); + } + + // Write final state + [[unroll]] for (uint i = 0; i < S_V; i++) { + dst[s_off + state_base + i * S_V + col] = state[i]; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index fb8941232bc..fe0fc4bea6e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -987,6 +987,8 @@ void process_shaders() { string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));