Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4567,7 +4567,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_f32_len, ssm_scan_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size, 16}, 1, true, true);
}

ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 1, 1}, {32}, 1);
ggml_vk_create_pipeline(device, device->pipeline_ssm_conv_f32, "ssm_conv_f32", ssm_conv_f32_len, ssm_conv_f32_data, "main", 3, sizeof(vk_op_ssm_conv_push_constants), {32, 16, 1}, {32, 16}, 1);

ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);

Expand Down
26 changes: 16 additions & 10 deletions ggml/src/ggml-vulkan/vulkan-shaders/ssm_conv.comp
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
#include "types.glsl"

layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(constant_id = 1) const uint TOKENS_PER_WG = 16;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z = 1) in;

layout(binding = 0) readonly buffer Src0 { float src0[]; };
layout(binding = 1) readonly buffer Src1 { float src1[]; };
Expand All @@ -20,25 +21,30 @@ layout(push_constant) uniform PushConstants {
};

void main() {
const uint global_thread_id = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y;
const uint i1 = gl_GlobalInvocationID.x;
const uint i2 = gl_WorkGroupID.y * TOKENS_PER_WG + gl_LocalInvocationID.y;
const uint i3 = gl_WorkGroupID.z;

if (global_thread_id >= nr || i2 >= n_t || i3 >= n_s) {
if (i1 >= nr || i2 >= n_t || i3 >= n_s) {
return;
}

const uint i1 = global_thread_id;
const uint src0_base = i3 * (nb02 / 4) + i2 + i1 * (nb01 / 4);
const uint src1_base = i1 * (nb11 / 4);
const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;

float sum = 0.0;
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
const uint src0_idx = src0_base + i0;
const uint src1_idx = src1_base + i0;
sum += src0[src0_idx] * src1[src1_idx];

if (nc == 4) {
sum = dot(
vec4(src0[src0_base], src0[src0_base + 1], src0[src0_base + 2], src0[src0_base + 3]),
vec4(src1[src1_base], src1[src1_base + 1], src1[src1_base + 2], src1[src1_base + 3])
);
} else {
[[unroll]] for (uint i0 = 0; i0 < nc; i0++) {
sum += src0[src0_base + i0] * src1[src1_base + i0];
}
}

const uint dst_idx = i3 * (dst_nb2 / 4) + i2 * (dst_nb1 / 4) + i1;
dst[dst_idx] = sum;
}
Loading