-
Notifications
You must be signed in to change notification settings - Fork 19.8k
CUDA: fuse SSM_CONV + ADD(bias) + SILU #22478
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 3 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,8 +1,9 @@ | ||
| #include "ssm-conv.cuh" | ||
| #include "unary.cuh" | ||
|
|
||
| template <bool apply_silu, size_t split_d_inner, size_t d_conv> | ||
| template <bool apply_bias, bool apply_silu, size_t split_d_inner, size_t d_conv> | ||
| static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
| const float * __restrict__ bias, | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is it really necessary to template the kernel from a perf perspective, as opposed to checking bias against nullptr (this can be done in the same ternary expression)? We should be mindful of binary bloat and only template what is truly necessary from a perf perspective. I'd imagine the same could apply to apply_silu as well, but that's beyond the scope of this PR.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Makes sense. I've done as you suggested now. |
||
| const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1, | ||
| float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2, | ||
| const int64_t n_t) { | ||
|
|
@@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float | |
| w[j] = w_block[tid * stride_w + j]; | ||
| } | ||
|
|
||
| float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; | ||
|
|
||
| for (int64_t i = 0; i < n_t; i++) { | ||
| float sumf = 0.0f; | ||
|
|
||
|
|
@@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float | |
| for (size_t j = 0; j < d_conv; j++) { | ||
| sumf += x[(i + j) % d_conv] * w[j]; | ||
| } | ||
| sumf = apply_bias ? sumf + b : sumf; | ||
| y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; | ||
| } | ||
| } | ||
|
|
||
| template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t> | ||
| template <bool apply_bias, bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t> | ||
| static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1, | ||
| const float * __restrict__ bias, | ||
|
anavp-nvidia marked this conversation as resolved.
Outdated
|
||
| const int src0_nb0, const int src0_nb1, const int src0_nb2, | ||
| const int src1_nb1, float * __restrict__ dst, const int dst_nb0, | ||
| const int dst_nb1, const int dst_nb2, const int64_t n_t) { | ||
|
|
@@ -97,19 +102,22 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, | |
| w[j] = w_block[tid * stride_w + j]; | ||
| } | ||
|
|
||
| float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f; | ||
|
|
||
| // Compute from shared memory | ||
| for (int64_t i = 0; i < local_n_t; i++) { | ||
| float sumf = 0.0f; | ||
| #pragma unroll | ||
| for (size_t j = 0; j < d_conv; j++) { | ||
| sumf += smem[tid * n_cols + i + j] * w[j]; | ||
| } | ||
| sumf = apply_bias ? sumf + b : sumf; | ||
| y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf; | ||
| } | ||
| } | ||
|
|
||
| template <bool apply_silu> | ||
| static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1, | ||
| template <bool apply_bias, bool apply_silu> | ||
| static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1, | ||
| const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1, | ||
| const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t, | ||
| const int64_t n_s, cudaStream_t stream) { | ||
|
|
@@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int | |
| constexpr int kNC = decltype(NC)::value; | ||
| if (n_t <= 32) { | ||
| const dim3 blocks(n_s, (nr + threads - 1) / threads, 1); | ||
| ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, | ||
| ssm_conv_f32<apply_bias, apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, | ||
| dst, dst_nb0, dst_nb1, dst_nb2, n_t); | ||
| } else { | ||
| const int64_t split_n_t = 32; | ||
| dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t); | ||
| const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float); | ||
| ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>( | ||
| src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); | ||
| ssm_conv_long_token_f32<apply_bias, apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>( | ||
| src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t); | ||
| } | ||
| }; | ||
|
|
||
|
|
@@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int | |
| } | ||
| } | ||
|
|
||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) { | ||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) { | ||
| const struct ggml_tensor * src0 = dst->src[0]; // conv_x | ||
| const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight | ||
| const bool fuse_bias = bias_add_node != nullptr; | ||
| const bool fuse_silu = silu_dst != nullptr; | ||
|
|
||
| // bias always comes with silu. | ||
| GGML_ASSERT(!fuse_bias || fuse_silu); | ||
|
|
||
| // The bias (when fused) is the non-conv operand of the ADD node. | ||
| const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr; | ||
|
|
||
| // When fusing, write to silu_dst (the node downstream references). | ||
| const struct ggml_tensor * out = fuse_silu ? silu_dst : dst; | ||
|
|
||
|
|
@@ -160,16 +175,26 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g | |
|
|
||
| const float * src0_d = (const float *) src0->data; | ||
| const float * src1_d = (const float *) src1->data; | ||
| const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr; | ||
| float * dst_d = (float *) out->data; | ||
| cudaStream_t stream = ctx.stream(); | ||
|
|
||
| GGML_ASSERT(src0->type == GGML_TYPE_F32); | ||
| GGML_ASSERT(out->type == GGML_TYPE_F32); | ||
| if (fuse_silu) { | ||
| ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], | ||
| if (fuse_bias) { | ||
| GGML_ASSERT(bias->type == GGML_TYPE_F32); | ||
| GGML_ASSERT(ggml_is_contiguous(bias)); | ||
| GGML_ASSERT(ggml_nelements(bias) == nr); | ||
| } | ||
|
|
||
| if (fuse_bias) { | ||
| ssm_conv_f32_cuda<true, true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], | ||
| out->nb[2], nc, nr, n_t, n_s, stream); | ||
| } else if (fuse_silu) { | ||
| ssm_conv_f32_cuda<false, true>(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], | ||
| out->nb[2], nc, nr, n_t, n_s, stream); | ||
| } else { | ||
| ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], | ||
| ssm_conv_f32_cuda<false, false>(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1], | ||
| out->nb[2], nc, nr, n_t, n_s, stream); | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,3 @@ | ||
| #include "common.cuh" | ||
|
|
||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr); | ||
| void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr); |
Uh oh!
There was an error while loading. Please reload this page.