Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 134 additions & 2 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,12 @@ static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGM
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };

// Op sequence of the unfused Snake activation, y = x + sin(a * x)^2 * inv_b:
//   MUL (a * x) -> SIN -> SQR -> MUL (* inv_b) -> ADD (+ x)
// Consumed by ggml_vk_graph_optimize (keep_pattern / match_pattern) so the reorder
// pass keeps the chain contiguous, and by ggml_vk_can_fuse_snake to detect the fusion.
static constexpr std::initializer_list<ggml_op> snake_pattern {
    GGML_OP_MUL,
    GGML_OP_SIN,
    GGML_OP_SQR,
    GGML_OP_MUL,
    GGML_OP_ADD
};

//node #978 ( SOFT_MAX): ffn_moe_probs-15 ( 0K) [Vulka ] use=2: ffn_moe_logits-15 ( 0K) [Vulka ]
//node #979 ( RESHAPE): ffn_moe_probs-15 (re ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
//node #980 ( ARGSORT): ffn_moe_argsort-15 ( 0K) [Vulka ] use=1: ffn_moe_probs-15 ( 0K) [Vulka ]
Expand Down Expand Up @@ -838,6 +844,9 @@ struct vk_device_struct {
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_snake_f32;
vk_pipeline pipeline_snake_f16;
vk_pipeline pipeline_snake_bf16;
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
Expand Down Expand Up @@ -1463,6 +1472,11 @@ struct vk_op_conv_transpose_1d_push_constants {
int32_t s0;
};

// Push constants for the fused snake pipelines (snake_f32 / snake_f16 / snake_bf16).
// Field order and width must mirror the `parameter` push_constant block in snake.comp.
struct vk_op_snake_push_constants {
    // Filled from x->ne[0]; also the X extent of the dispatch.
    uint32_t ne0;
    // Filled from x->ne[1]; also the Y extent of the dispatch.
    uint32_t ne1;
};

struct vk_op_pool2d_push_constants {
uint32_t IW; uint32_t IH;
uint32_t OW; uint32_t OH;
Expand Down Expand Up @@ -4753,6 +4767,10 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_snake_f32, "snake_f32", snake_f32_len, snake_f32_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_snake_f16, "snake_f16", snake_f16_len, snake_f16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_snake_bf16, "snake_bf16", snake_bf16_len, snake_bf16_data, "main", 4, sizeof(vk_op_snake_push_constants), {256, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);
Expand Down Expand Up @@ -11957,6 +11975,45 @@ static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context&
ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p));
}

// Dispatch the fused snake activation: y = x + sin^2(a * x) * inv_b.
//
// Replaces the naive mul -> sin -> sqr -> mul -> add chain that starts at
// cgraph->nodes[node_idx] with a single run of the dedicated snake pipeline.
// The chain must already have been accepted by ggml_vk_can_fuse_snake: the
// operands are re-derived here, but nothing is re-validated.
//
//   ctx      - backend context; supplies the device pipelines
//   subctx   - forwarded to ggml_vk_dispatch_pipeline
//   cgraph   - graph holding the five chain nodes
//   node_idx - index of the first MUL of the chain
//
// Aborts via GGML_ABORT if x is not F32, F16 or BF16.
static void ggml_vk_snake_dispatch_fused(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
    const ggml_tensor * mul0 = cgraph->nodes[node_idx + 0];
    const ggml_tensor * sqr  = cgraph->nodes[node_idx + 2];
    const ggml_tensor * mul1 = cgraph->nodes[node_idx + 3];
    ggml_tensor *       add  = cgraph->nodes[node_idx + 4];

    // x carries the full activation shape, a is the broadcast operand
    const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
    const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];

    // mul1 reads sqr and inv_b in either operand order
    const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];

    // One pipeline per activation type; the broadcast operands are always bound as float.
    vk_pipeline pipeline = nullptr;
    switch (x->type) {
        case GGML_TYPE_F32:  pipeline = ctx->device->pipeline_snake_f32;  break;
        case GGML_TYPE_F16:  pipeline = ctx->device->pipeline_snake_f16;  break;
        case GGML_TYPE_BF16: pipeline = ctx->device->pipeline_snake_bf16; break;
        default: GGML_ABORT("unsupported type");
    }
    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    // Binding order matches snake.comp: 0 = x, 1 = a, 2 = inv_b, 3 = output.
    vk_subbuffer x_buf     = ggml_vk_tensor_subbuffer(ctx, x);
    vk_subbuffer a_buf     = ggml_vk_tensor_subbuffer(ctx, a);
    vk_subbuffer inv_b_buf = ggml_vk_tensor_subbuffer(ctx, inv_b);
    vk_subbuffer dst_buf   = ggml_vk_tensor_subbuffer(ctx, add);

    vk_op_snake_push_constants pc{};
    pc.ne0 = static_cast<uint32_t>(x->ne[0]);
    pc.ne1 = static_cast<uint32_t>(x->ne[1]);

    // One invocation per element: X spans ne0, Y spans ne1.
    std::array<uint32_t, 3> elements = { pc.ne0, pc.ne1, 1 };
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { x_buf, a_buf, inv_b_buf, dst_buf }, pc, elements);
}

static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
uint32_t op = static_cast<uint32_t>(dst->op_params[0]);
const int32_t k1 = dst->op_params[1];
Expand Down Expand Up @@ -13165,7 +13222,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;
case GGML_OP_MUL:
    // A non-zero fused-op count on a MUL node is treated as the start of a
    // snake chain (mul -> sin -> sqr -> mul -> add) and handed to the fused
    // dispatcher; otherwise the node is a plain multiply.
    // NOTE(review): this assumes snake is the only fusion whose first node is
    // GGML_OP_MUL - confirm, or key this branch on an explicit snake marker.
    if (ctx->num_additional_fused_ops) {
        ggml_vk_snake_dispatch_fused(ctx, compute_ctx, cgraph, node_idx);
    } else {
        ggml_vk_mul(ctx, compute_ctx, src0, src1, node);
    }

    break;
case GGML_OP_DIV:
Expand Down Expand Up @@ -14482,6 +14543,65 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const
return true;
}

// Pattern check for the 5-op Snake fusion: mul -> sin -> sqr -> mul -> add,
// i.e. add = x + sqr(sin(a * x)) * inv_b.
//
// Returns true only when the chain starting at cgraph->nodes[node_idx] can be
// replaced by a single snake.comp dispatch:
//   - ggml_can_fuse accepts snake_pattern at node_idx
//   - the tensor added at the end is the same x that feeds the first mul
//   - x is F32 / F16 / BF16 and every chain node has x's type
//   - a and inv_b are F32, identically shaped, laid out as [1, C] with C == x->ne[1]
//   - x, add, a and inv_b are contiguous and have ne[2] == ne[3] == 1
//   - the shape fits the 32-bit sizes and flat index used by the dispatcher and shader
//
// ctx is currently unused.
static bool ggml_vk_can_fuse_snake(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) {
    GGML_UNUSED(ctx);
    if (!ggml_can_fuse(cgraph, node_idx, snake_pattern)) {
        return false;
    }

    const ggml_tensor * mul0     = cgraph->nodes[node_idx + 0];
    const ggml_tensor * sin_node = cgraph->nodes[node_idx + 1];
    const ggml_tensor * sqr      = cgraph->nodes[node_idx + 2];
    const ggml_tensor * mul1     = cgraph->nodes[node_idx + 3];
    const ggml_tensor * add      = cgraph->nodes[node_idx + 4];

    // x carries the full activation shape, a is the broadcast operand
    const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
    const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];

    // mul1 and add take their chain input in either operand order
    const ggml_tensor * inv_b    = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];
    const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];

    // Closure: the residual added at the end must be the x that entered the chain.
    if (x_in_add != x) {
        return false;
    }
    if (x->type != GGML_TYPE_F32 && x->type != GGML_TYPE_F16 && x->type != GGML_TYPE_BF16) {
        return false;
    }
    // Shader bindings: data_a is A_TYPE so it follows x's precision, while
    // data_b and data_c are hardcoded float, so the broadcast operands must
    // be F32 regardless of x's type.
    if (a->type != GGML_TYPE_F32) return false;
    if (inv_b->type != GGML_TYPE_F32) return false;
    // Chain intermediates and output share x's precision (single A_TYPE / D_TYPE pipeline).
    if (mul0->type != x->type) return false;
    if (sin_node->type != x->type) return false;
    if (sqr->type != x->type) return false;
    if (mul1->type != x->type) return false;
    if (add->type != x->type) return false;
    if (!ggml_are_same_shape(a, inv_b)) {
        return false;
    }
    if (a->ne[0] != 1 || a->ne[1] != x->ne[1]) {
        return false;
    }
    // Dispatch is 2D over (ne0, ne1), so x and add must be 2D and a / inv_b
    // must collapse to [1, C, 1, 1]. Higher dims are not handled by the shader.
    if (x->ne[2] != 1 || x->ne[3] != 1) return false;
    if (add->ne[2] != 1 || add->ne[3] != 1) return false;
    if (a->ne[2] != 1 || a->ne[3] != 1) return false;
    if (inv_b->ne[2] != 1 || inv_b->ne[3] != 1) return false;
    // The dispatcher narrows ne0 / ne1 to uint32_t and the shader computes
    // idx = i0 + i1 * ne0 in 32 bits. Reject shapes that would be truncated or
    // wrap, so they fall back to the unfused ops instead of computing a
    // partial result.
    // NOTE(review): ne1 is also used directly as the Y work-group count;
    // confirm it cannot exceed the device's maxComputeWorkGroupCount[1].
    const int64_t ne0 = x->ne[0];
    const int64_t ne1 = x->ne[1];
    if (ne0 < 0 || ne1 < 0 || ne0 > UINT32_MAX || ne1 > UINT32_MAX) {
        return false;
    }
    if (static_cast<uint64_t>(ne0) * static_cast<uint64_t>(ne1) > UINT32_MAX) {
        return false;
    }
    // Shader uses idx = i0 + i1 * ne0 and reads data_b[i1] / data_c[i1],
    // so every operand must be contiguous.
    if (!ggml_is_contiguous(x) || !ggml_is_contiguous(add) ||
        !ggml_is_contiguous(a) || !ggml_is_contiguous(inv_b)) {
        return false;
    }
    return true;
}

// Check whether the tensors overlap in memory.
// Fusions can potentially overwrite src tensors in ways that are not prevented
// by ggml-alloc. If the fusion src is being applied in a way that's elementwise
Expand Down Expand Up @@ -14776,6 +14896,14 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
op_srcs_fused_elementwise[0] = false;
op_srcs_fused_elementwise[1] = false;
op_srcs_fused_elementwise[2] = false;
} else if (ggml_vk_can_fuse_snake(ctx, cgraph, i)) {
ctx->num_additional_fused_ops = 4;
fusion_string = "SNAKE";
// elementwise=true: snake.comp is safe under exact aliasing because each
// thread reads data_a[idx] (the x binding) into a register before writing
// data_d[idx] with a data dependency on that register. The overlap check
// still rejects partial overlaps (different base or size).
std::fill_n(op_srcs_fused_elementwise, 5, true);
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) {
Expand Down Expand Up @@ -15066,6 +15194,9 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (keep_pattern(topk_moe_late_softmax)) {
continue;
}
if (keep_pattern(snake_pattern)) {
continue;
}

// First, grab the next unused node.
current_set.push_back(first_unused);
Expand All @@ -15088,7 +15219,8 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
// Skip j when one of the fusion patterns starts there, so the reorder pass
// leaves that pattern's nodes contiguous.
if (match_pattern(topk_moe_early_softmax_norm, j) ||
    match_pattern(topk_moe_sigmoid_norm_bias, j) ||
    match_pattern(topk_moe_early_softmax, j) ||
    match_pattern(topk_moe_late_softmax, j) ||
    match_pattern(snake_pattern, j)) {
    continue;
}
bool ok = true;
Expand Down
49 changes: 49 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/snake.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#version 450

#include "types.glsl"

// Fused snake activation: y = x + sin(b * x)^2 * c
// data_a [ne0, ne1] per element activation x (A_TYPE)
// data_b [1, ne1] per channel multiplier (float)
// data_c [1, ne1] per channel inverse scale (float, precomputed as 1 / freq)
// data_d [ne0, ne1] output y (D_TYPE)
//
// Each invocation handles one element: gl_GlobalInvocationID.x selects the
// position inside a row (0 .. ne0-1) and gl_GlobalInvocationID.y selects the
// row (0 .. ne1-1), which is also the index into data_b / data_c.
// The host binds x, a, inv_b and the output tensor to bindings 0..3 in that order.
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {float data_b[];};
layout (binding = 2) readonly buffer C {float data_c[];};
layout (binding = 3) writeonly buffer D {D_TYPE data_d[];};

// 256 invocations along X per work group; the host creates the snake pipelines
// with matching {256, 1, 1} work-group denominators.
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;

// Mirrors vk_op_snake_push_constants on the host: ne0 = x->ne[0], ne1 = x->ne[1].
layout (push_constant) uniform parameter {
uint32_t ne0;
uint32_t ne1;
} p;

// Read element idx of the input buffer and widen it to float.
// In the BF16 variant A_TYPE is uint16_t, so the raw bits are expanded with
// bf16_to_fp32; every other variant uses a plain float() conversion.
float load_val(uint32_t idx) {
#ifdef DATA_A_BF16
    const uint32_t raw_bits = uint32_t(data_a[idx]);
    return bf16_to_fp32(raw_bits);
#else
    return float(data_a[idx]);
#endif
}

// Narrow v to D_TYPE and write it to element idx of the output buffer.
// In the BF16 variant D_TYPE is uint16_t, so v is first encoded with
// fp32_to_bf16; every other variant converts with D_TYPE() directly.
void store_val(uint32_t idx, float v) {
#ifndef DATA_D_BF16
    data_d[idx] = D_TYPE(v);
#else
    data_d[idx] = D_TYPE(fp32_to_bf16(v));
#endif
}

// One invocation per element of x: computes y = x + sin(b * x)^2 * c for the
// element at (col, row) and writes it to the same flat position of the output.
void main() {
    const uint32_t col = gl_GlobalInvocationID.x;
    const uint32_t row = gl_GlobalInvocationID.y;

    // The dispatch is rounded up to whole work groups; drop the excess invocations.
    if (!(col < p.ne0 && row < p.ne1)) {
        return;
    }

    // Rows are contiguous, ne0 elements each.
    const uint32_t flat = col + row * p.ne0;

    const float x_val   = load_val(flat);
    const float sin_val = sin(data_b[row] * x_val);
    store_val(flat, x_val + sin_val * sin_val * data_c[row]);
}
4 changes: 4 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -970,6 +970,10 @@ void process_shaders() {

string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

string_to_spv("snake_f32", "snake.comp", {{"DATA_A_F32", "1"}, {"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("snake_f16", "snake.comp", {{"DATA_A_F16", "1"}, {"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("snake_bf16", "snake.comp", {{"DATA_A_BF16", "1"}, {"DATA_D_BF16", "1"}, {"A_TYPE", "uint16_t"}, {"D_TYPE", "uint16_t"}});

string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));

string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
Expand Down
Loading