From f8742f86cf4062bb12bccb3d63e5a143f1c6551a Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Tue, 28 Apr 2026 12:26:13 +0000 Subject: [PATCH 1/4] CUDA: ssm_conv + bias + silu fusion --- ggml/src/ggml-cuda/ggml-cuda.cu | 29 ++++++++++++- ggml/src/ggml-cuda/ssm-conv.cu | 47 ++++++++++++++++----- ggml/src/ggml-cuda/ssm-conv.cuh | 2 +- tests/test-backend-ops.cpp | 73 +++++++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 13 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index fd8dd91714ca..8d92fc5b5699 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3564,6 +3564,28 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, return true; } + if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD + && ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { + const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; + const ggml_tensor * add = cgraph->nodes[node_idx+1]; + const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + + if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { + return false; + } + + // ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias. + const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0]; + if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) { + return false; + } + if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) { + return false; + } + + return true; + } + if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL && unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) { const ggml_tensor * unary = cgraph->nodes[node_idx]; @@ -3966,8 +3988,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph return 1; } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { + ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]); + return 2; + } + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) { - ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]); + ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]); return 1; } diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index b77cdc1c1376..211fcb12440f 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,8 +1,9 @@ #include "ssm-conv.cuh" #include "unary.cuh" -template +template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } + float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; + for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } + sumf = apply_bias ? sumf + b : sumf; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template +template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, + const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t n_t) { @@ -97,6 +102,8 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } + float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; + // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { float sumf = 0.0f; @@ -104,12 +111,13 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } + sumf = apply_bias ? sumf + b : sumf; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template -static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, +template +static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, const int64_t n_s, cudaStream_t stream) { @@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); - ssm_conv_long_token_f32<<>>( - src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); + ssm_conv_long_token_f32<<>>( + src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int } } -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { const struct ggml_tensor * src0 = dst->src[0]; // conv_x const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight + const bool fuse_bias = bias_add_node != nullptr; const bool fuse_silu = silu_dst != nullptr; + // bias always comes with silu. + GGML_ASSERT(!fuse_bias || fuse_silu); + + // The bias (when fused) is the non-conv operand of the ADD node. + const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; + // When fusing, write to silu_dst (the node downstream references). const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; @@ -160,16 +175,26 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g const float * src0_d = (const float *) src0->data; const float * src1_d = (const float *) src1->data; + const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; float * dst_d = (float *) out->data; cudaStream_t stream = ctx.stream(); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(out->type == GGML_TYPE_F32); - if (fuse_silu) { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + if (fuse_bias) { + GGML_ASSERT(bias->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(bias)); + GGML_ASSERT(ggml_nelements(bias) == nr); + } + + if (fuse_bias) { + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + out->nb[2], nc, nr, n_t, n_s, stream); + } else if (fuse_silu) { + ssm_conv_f32_cuda(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } else { - ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } } diff --git a/ggml/src/ggml-cuda/ssm-conv.cuh b/ggml/src/ggml-cuda/ssm-conv.cuh index f96a1cd2484b..8514ca84920f 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cuh +++ b/ggml/src/ggml-cuda/ssm-conv.cuh @@ -1,3 +1,3 @@ #include "common.cuh" -void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); +void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 941c20ce19d1..b7e56b6eb788 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3579,6 +3579,55 @@ struct test_ssm_conv : public test_case { } }; +// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias) + GGML_OP_UNARY(SILU) (fused operation) +struct test_ssm_conv_bias_silu : public test_case { + const ggml_type type; + const std::array ne_a; + const std::array ne_b; + const bool fuse_bias; + const bool fuse_silu; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "SSM_CONV_BIAS_SILU"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR5(type, ne_a, ne_b, fuse_bias, fuse_silu); + } + + test_ssm_conv_bias_silu(ggml_type type = GGML_TYPE_F32, + std::array ne_a = {4, 128, 1, 1}, + std::array ne_b = {4, 128, 1, 1}, + bool fuse_bias = true, + bool fuse_silu = true) + : type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias), fuse_silu(fuse_silu) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_set_name(a, "a"); + ggml_set_name(b, "b"); + + ggml_tensor * out = ggml_ssm_conv(ctx, a, b); + + if (fuse_bias) { + ggml_tensor * bias = ggml_new_tensor_1d(ctx, type, out->ne[0]); + ggml_set_name(bias, "bias"); + out = ggml_add(ctx, out, bias); + } + + if (fuse_silu) { + out = ggml_silu(ctx, out); + } + + ggml_set_name(out, "out"); + return out; + } +}; + // GGML_OP_SSM_SCAN struct test_ssm_scan : public test_case { const ggml_type type; @@ -7977,6 +8026,28 @@ static std::vector> make_test_cases_eval() { } } + // fused ssm_conv + (optional) bias_add + silu. The bias-only graph (no silu) is intentionally + // not tested since there's no fusion for that pattern in ggml_cuda_can_fuse. + for (int64_t d_conv : {3, 4, 9}) { + for (int64_t d_inner : {1024, 1536, 2048}) { + for (bool fuse_bias : {false, true}) { + const bool fuse_silu = true; + // short token path (n_t <= 32) + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + // long token path (n_t > 32) + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + test_cases.emplace_back(new test_ssm_conv_bias_silu( + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + } + } + } + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1 @@ -8993,6 +9064,8 @@ static std::vector> make_test_cases_perf() { // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // prefill + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // generate test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate From aa960f73527c56dffe249520dc5ab1f89545ed9e Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 08:08:37 +0000 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: Gaurav Garg --- ggml/src/ggml-cuda/ggml-cuda.cu | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 8d92fc5b5699..f8b2272d24a1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3555,7 +3555,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; - const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; @@ -3569,6 +3571,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; const ggml_tensor * add = cgraph->nodes[node_idx+1]; const ggml_tensor * silu = cgraph->nodes[node_idx+2]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; From 3a7085f74540f10041fececd6b6d15ed76971e08 Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 08:10:51 +0000 Subject: [PATCH 3/4] adding back accidentally deleted line --- ggml/src/ggml-cuda/ggml-cuda.cu | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index f8b2272d24a1..0e6f74685d67 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3555,9 +3555,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, if (ops.size() == 2 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) { const ggml_tensor * ssm_conv = cgraph->nodes[node_idx]; - if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { - return false; - } + const ggml_tensor * silu = cgraph->nodes[node_idx+1]; + if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) { + return false; + } if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) { return false; From 373ca030356b2ae036728e4e218fb4405162d42d Mon Sep 17 00:00:00 2001 From: Anav Prasad Date: Wed, 29 Apr 2026 10:21:39 +0000 Subject: [PATCH 4/4] simplify ssm_conv fusion templating and test struct --- ggml/src/ggml-cuda/ssm-conv.cu | 27 ++++++++++++--------------- tests/test-backend-ops.cpp | 33 +++++++++++++-------------------- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/ggml/src/ggml-cuda/ssm-conv.cu b/ggml/src/ggml-cuda/ssm-conv.cu index 211fcb12440f..4841389fbc88 100644 --- a/ggml/src/ggml-cuda/ssm-conv.cu +++ b/ggml/src/ggml-cuda/ssm-conv.cu @@ -1,7 +1,7 @@ #include "ssm-conv.cuh" #include "unary.cuh" -template +template static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, @@ -28,7 +28,7 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float w[j] = w_block[tid * stride_w + j]; } - float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; for (int64_t i = 0; i < n_t; i++) { float sumf = 0.0f; @@ -45,12 +45,12 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float for (size_t j = 0; j < d_conv; j++) { sumf += x[(i + j) % d_conv] * w[j]; } - sumf = apply_bias ? sumf + b : sumf; + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template +template static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, @@ -102,7 +102,7 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, w[j] = w_block[tid * stride_w + j]; } - float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; + float b = bias != nullptr ? bias[bidy * split_d_inner + tid] : 0.0f; // Compute from shared memory for (int64_t i = 0; i < local_n_t; i++) { @@ -111,12 +111,12 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, for (size_t j = 0; j < d_conv; j++) { sumf += smem[tid * n_cols + i + j] * w[j]; } - sumf = apply_bias ? sumf + b : sumf; + sumf += b; y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; } } -template +template static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, @@ -128,13 +128,13 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const floa constexpr int kNC = decltype(NC)::value; if (n_t <= 32) { const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); - ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, + ssm_conv_f32<<>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } else { const int64_t split_n_t = 32; dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); - ssm_conv_long_token_f32<<>>( + ssm_conv_long_token_f32<<>>( src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); } }; @@ -187,14 +187,11 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g GGML_ASSERT(ggml_nelements(bias) == nr); } - if (fuse_bias) { - ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], - out->nb[2], nc, nr, n_t, n_s, stream); - } else if (fuse_silu) { - ssm_conv_f32_cuda(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + if (fuse_silu) { + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } else { - ssm_conv_f32_cuda(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], + ssm_conv_f32_cuda(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], out->nb[2], nc, nr, n_t, n_s, stream); } } diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b7e56b6eb788..be2a6ef5641d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3579,13 +3579,12 @@ struct test_ssm_conv : public test_case { } }; -// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias) + GGML_OP_UNARY(SILU) (fused operation) +// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias, optional) + GGML_OP_UNARY(SILU) (fused operation) struct test_ssm_conv_bias_silu : public test_case { const ggml_type type; const std::array ne_a; const std::array ne_b; const bool fuse_bias; - const bool fuse_silu; std::string op_desc(ggml_tensor * t) override { GGML_UNUSED(t); @@ -3595,15 +3594,12 @@ struct test_ssm_conv_bias_silu : public test_case { bool run_whole_graph() override { return true; } std::string vars() override { - return VARS_TO_STR5(type, ne_a, ne_b, fuse_bias, fuse_silu); + return VARS_TO_STR4(type, ne_a, ne_b, fuse_bias); } - test_ssm_conv_bias_silu(ggml_type type = GGML_TYPE_F32, - std::array ne_a = {4, 128, 1, 1}, - std::array ne_b = {4, 128, 1, 1}, - bool fuse_bias = true, - bool fuse_silu = true) - : type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias), fuse_silu(fuse_silu) {} + test_ssm_conv_bias_silu(ggml_type type, std::array ne_a, std::array ne_b, + bool fuse_bias) + : type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); @@ -3619,9 +3615,7 @@ struct test_ssm_conv_bias_silu : public test_case { out = ggml_add(ctx, out, bias); } - if (fuse_silu) { - out = ggml_silu(ctx, out); - } + out = ggml_silu(ctx, out); ggml_set_name(out, "out"); return out; @@ -8031,19 +8025,18 @@ static std::vector> make_test_cases_eval() { for (int64_t d_conv : {3, 4, 9}) { for (int64_t d_inner : {1024, 1536, 2048}) { for (bool fuse_bias : {false, true}) { - const bool fuse_silu = true; // short token path (n_t <= 32) test_cases.emplace_back(new test_ssm_conv_bias_silu( - GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); test_cases.emplace_back(new test_ssm_conv_bias_silu( - GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); test_cases.emplace_back(new test_ssm_conv_bias_silu( - GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); // long token path (n_t > 32) test_cases.emplace_back(new test_ssm_conv_bias_silu( - GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); test_cases.emplace_back(new test_ssm_conv_bias_silu( - GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu)); + GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias)); } } } @@ -9064,8 +9057,8 @@ static std::vector> make_test_cases_perf() { // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate - test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // prefill - test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // generate + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // prefill + test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true)); // generate test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate