Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3930,9 +3930,23 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];

const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
// Kernel iterates over total = T * C, so x and add must be 2D and
// a / inv_b must collapse to [1, C, 1, 1]. Higher dims are not handled.
const bool dim_ok = (x->ne[2] == 1 && x->ne[3] == 1) &&
(add->ne[2] == 1 && add->ne[3] == 1) &&
(a->ne[2] == 1 && a->ne[3] == 1);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];

if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
// Defensive: every operand and intermediate result must share x's type,
// since launch_snake casts a / inv_b as float and templates the kernel
// on a single T. Mixed precision chains fall back to the naive path.
const ggml_tensor * sin1 = cgraph->nodes[i + 1];
const bool types_ok = (a->type == x->type) && (inv_b->type == x->type) &&
Comment thread
ServeurpersoCom marked this conversation as resolved.
Outdated
(mul0->type == x->type) && (sin1->type == x->type) &&
(sqr->type == x->type) && (mul1->type == x->type) &&
(add->type == x->type);

if (type_ok && shape_ok && dim_ok && types_ok && x_in_add == x) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
Expand Down