Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3556,6 +3556,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
&& unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
const ggml_tensor * silu = cgraph->nodes[node_idx+1];
Comment thread
anavp-nvidia marked this conversation as resolved.
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
return false;
}

if (ssm_conv->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
return false;
Expand All @@ -3564,6 +3567,31 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
return true;
}

if (ops.size() == 3 && ops.begin()[0] == GGML_OP_SSM_CONV && ops.begin()[1] == GGML_OP_ADD
&& ops.begin()[2] == GGML_OP_UNARY && unary_ops.size() == 1 && unary_ops.begin()[0] == GGML_UNARY_OP_SILU) {
const ggml_tensor * ssm_conv = cgraph->nodes[node_idx];
const ggml_tensor * add = cgraph->nodes[node_idx+1];
const ggml_tensor * silu = cgraph->nodes[node_idx+2];
Comment thread
anavp-nvidia marked this conversation as resolved.
if (ggml_get_unary_op(silu) != unary_ops.begin()[0]) {
return false;
}

if (ssm_conv->type != GGML_TYPE_F32 || add->type != GGML_TYPE_F32 || silu->type != GGML_TYPE_F32) {
return false;
}

// ADD must consume ssm_conv's output and broadcast a 1-D channel-wise bias.
const ggml_tensor * bias = (add->src[0] == ssm_conv) ? add->src[1] : add->src[0];
if (bias->type != GGML_TYPE_F32 || !ggml_is_contiguous(bias)) {
return false;
}
if (ggml_nelements(bias) != ssm_conv->ne[0] || bias->ne[0] != ssm_conv->ne[0]) {
return false;
}

return true;
}

if (ops.size() == 2 && ops.begin()[0] == GGML_OP_UNARY && ops.begin()[1] == GGML_OP_MUL
&& unary_ops.size() == 1 && (unary_ops.begin()[0] == GGML_UNARY_OP_SILU || unary_ops.begin()[0] == GGML_UNARY_OP_SIGMOID || unary_ops.begin()[0] == GGML_UNARY_OP_SOFTPLUS)) {
const ggml_tensor * unary = cgraph->nodes[node_idx];
Expand Down Expand Up @@ -3966,8 +3994,13 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
return 1;
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_ADD, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1], cgraph->nodes[i + 2]);
return 2;
}

if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SSM_CONV, GGML_OP_UNARY }, { GGML_UNARY_OP_SILU })) {
ggml_cuda_op_ssm_conv(*cuda_ctx, node, cgraph->nodes[i + 1]);
ggml_cuda_op_ssm_conv(*cuda_ctx, node, /*bias_add_node=*/ nullptr, cgraph->nodes[i + 1]);
return 1;
}

Expand Down
47 changes: 36 additions & 11 deletions ggml/src/ggml-cuda/ssm-conv.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "ssm-conv.cuh"
#include "unary.cuh"

template <bool apply_silu, size_t split_d_inner, size_t d_conv>
template <bool apply_bias, bool apply_silu, size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const float * __restrict__ bias,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it really necessary to template the kernel from a perf-perspective as opposed to checking bias against nullptr (this can be done in the same ternary expression)? We should be mindful of binary bloat and only template that which is truly necessary from a perf perspective.

I'd imagine the same can potentially apply to apply_silu as well, but that's beyond the scope of this PR

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense. I've done as you've suggested now.

const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int64_t n_t) {
Expand All @@ -27,6 +28,8 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
w[j] = w_block[tid * stride_w + j];
}

float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f;

for (int64_t i = 0; i < n_t; i++) {
float sumf = 0.0f;

Expand All @@ -42,12 +45,14 @@ static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
sumf = apply_bias ? sumf + b : sumf;
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}

template <bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
template <bool apply_bias, bool apply_silu, size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const float * __restrict__ bias,
Comment thread
anavp-nvidia marked this conversation as resolved.
Outdated
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
Expand Down Expand Up @@ -97,19 +102,22 @@ static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0,
w[j] = w_block[tid * stride_w + j];
}

float b = apply_bias ? bias[bidy * split_d_inner + tid] : 0.0f;

// Compute from shared memory
for (int64_t i = 0; i < local_n_t; i++) {
float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
sumf += smem[tid * n_cols + i + j] * w[j];
}
sumf = apply_bias ? sumf + b : sumf;
y_block[i * stride_y + tid] = apply_silu ? ggml_cuda_op_silu_single(sumf) : sumf;
}
}

template <bool apply_silu>
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
template <bool apply_bias, bool apply_silu>
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const float * bias, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
const int64_t n_s, cudaStream_t stream) {
Expand All @@ -120,14 +128,14 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
constexpr int kNC = decltype(NC)::value;
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
ssm_conv_f32<apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
ssm_conv_f32<apply_bias, apply_silu, threads, kNC><<<blocks, threads, 0, stream>>>(src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
const size_t smem_size = threads * (kNC - 1 + split_n_t) * sizeof(float);
ssm_conv_long_token_f32<apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
ssm_conv_long_token_f32<apply_bias, apply_silu, threads, kNC, split_n_t><<<blocks, threads, smem_size, stream>>>(
src0, src1, bias, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
}
};

Expand All @@ -140,11 +148,18 @@ static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int
}
}

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst) {
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node, ggml_tensor * silu_dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const bool fuse_bias = bias_add_node != nullptr;
const bool fuse_silu = silu_dst != nullptr;

// bias always comes with silu.
GGML_ASSERT(!fuse_bias || fuse_silu);

// The bias (when fused) is the non-conv operand of the ADD node.
const struct ggml_tensor * bias = fuse_bias ? (bias_add_node->src[0] == dst ? bias_add_node->src[1] : bias_add_node->src[0]) : nullptr;

// When fusing, write to silu_dst (the node downstream references).
const struct ggml_tensor * out = fuse_silu ? silu_dst : dst;

Expand All @@ -160,16 +175,26 @@ void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, g

const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * bias_d = fuse_bias ? (const float *) bias->data : nullptr;
float * dst_d = (float *) out->data;
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(out->type == GGML_TYPE_F32);
if (fuse_silu) {
ssm_conv_f32_cuda<true>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
if (fuse_bias) {
GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(bias));
GGML_ASSERT(ggml_nelements(bias) == nr);
}

if (fuse_bias) {
ssm_conv_f32_cuda<true, true>(src0_d, src1_d, bias_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
out->nb[2], nc, nr, n_t, n_s, stream);
} else if (fuse_silu) {
ssm_conv_f32_cuda<false, true>(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
out->nb[2], nc, nr, n_t, n_s, stream);
} else {
ssm_conv_f32_cuda<false>(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
ssm_conv_f32_cuda<false, false>(src0_d, src1_d, nullptr, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, out->nb[0], out->nb[1],
out->nb[2], nc, nr, n_t, n_s, stream);
}
}
2 changes: 1 addition & 1 deletion ggml/src/ggml-cuda/ssm-conv.cuh
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#include "common.cuh"

void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * silu_dst = nullptr);
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * bias_add_node = nullptr, ggml_tensor * silu_dst = nullptr);
73 changes: 73 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3579,6 +3579,55 @@ struct test_ssm_conv : public test_case {
}
};

// GGML_OP_SSM_CONV + GGML_OP_ADD (channel-wise bias) + GGML_OP_UNARY(SILU) (fused operation)
struct test_ssm_conv_bias_silu : public test_case {
Comment thread
anavp-nvidia marked this conversation as resolved.
const ggml_type type;
const std::array<int64_t, 4> ne_a;
const std::array<int64_t, 4> ne_b;
const bool fuse_bias;
const bool fuse_silu;

std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return "SSM_CONV_BIAS_SILU";
}

bool run_whole_graph() override { return true; }

std::string vars() override {
return VARS_TO_STR5(type, ne_a, ne_b, fuse_bias, fuse_silu);
}

test_ssm_conv_bias_silu(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne_a = {4, 128, 1, 1},
std::array<int64_t, 4> ne_b = {4, 128, 1, 1},
bool fuse_bias = true,
bool fuse_silu = true)
Comment thread
anavp-nvidia marked this conversation as resolved.
Outdated
: type(type), ne_a(ne_a), ne_b(ne_b), fuse_bias(fuse_bias), fuse_silu(fuse_silu) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data());
ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data());
ggml_set_name(a, "a");
ggml_set_name(b, "b");

ggml_tensor * out = ggml_ssm_conv(ctx, a, b);

if (fuse_bias) {
ggml_tensor * bias = ggml_new_tensor_1d(ctx, type, out->ne[0]);
ggml_set_name(bias, "bias");
out = ggml_add(ctx, out, bias);
}

if (fuse_silu) {
out = ggml_silu(ctx, out);
}

ggml_set_name(out, "out");
return out;
}
};

// GGML_OP_SSM_SCAN
struct test_ssm_scan : public test_case {
const ggml_type type;
Expand Down Expand Up @@ -7977,6 +8026,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
}
}

// fused ssm_conv + (optional) bias_add + silu. The bias-only graph (no silu) is intentionally
// not tested since there's no fusion for that pattern in ggml_cuda_can_fuse.
for (int64_t d_conv : {3, 4, 9}) {
for (int64_t d_inner : {1024, 1536, 2048}) {
for (bool fuse_bias : {false, true}) {
const bool fuse_silu = true;
// short token path (n_t <= 32)
test_cases.emplace_back(new test_ssm_conv_bias_silu(
GGML_TYPE_F32, {d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu));
test_cases.emplace_back(new test_ssm_conv_bias_silu(
GGML_TYPE_F32, {2 * d_conv, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu));
test_cases.emplace_back(new test_ssm_conv_bias_silu(
GGML_TYPE_F32, {d_conv, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu));
// long token path (n_t > 32)
test_cases.emplace_back(new test_ssm_conv_bias_silu(
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 1, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu));
test_cases.emplace_back(new test_ssm_conv_bias_silu(
GGML_TYPE_F32, {d_conv - 1 + 64, d_inner, 4, 1}, {d_conv, d_inner, 1, 1}, fuse_bias, fuse_silu));
}
}
}

test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 16, 2, 32, 4)); // Mamba-2
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 256, 64, 8, 2, 32, 4)); // Falcon-H1
Expand Down Expand Up @@ -8993,6 +9064,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
// Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill
test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // prefill
test_cases.emplace_back(new test_ssm_conv_bias_silu(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1}, true, true)); // generate
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill
test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate

Expand Down
Loading