Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,12 @@ struct vk_device_struct {
vk_pipeline pipeline_pool2d_f32;
vk_pipeline pipeline_rwkv_wkv6_f32;
vk_pipeline pipeline_rwkv_wkv7_f32;
vk_pipeline pipeline_gated_delta_net_f32_d32;
vk_pipeline pipeline_gated_delta_net_f32_d64;
vk_pipeline pipeline_gated_delta_net_f32_d128;
vk_pipeline pipeline_gated_delta_net_f32_d32_kda;
vk_pipeline pipeline_gated_delta_net_f32_d64_kda;
vk_pipeline pipeline_gated_delta_net_f32_d128_kda;
vk_pipeline pipeline_ssm_scan_f32_d128;
vk_pipeline pipeline_ssm_scan_f32_d256;
vk_pipeline pipeline_ssm_conv_f32;
Expand Down Expand Up @@ -1439,6 +1445,17 @@ struct vk_op_rwkv_wkv7_push_constants {
uint32_t C;
uint32_t H;
};
// Push constants for the gated-delta-net kernel.
// Field order and types must match the push_constant block in
// vulkan-shaders/gated_delta_net.comp exactly (std430 layout, all uint32).
struct vk_op_gated_delta_net_push_constants {
    uint32_t H;             // number of heads
    uint32_t n_tokens;      // tokens per sequence
    uint32_t n_seqs;        // number of sequences (batch)
    uint32_t s_off;         // offset (in floats) in dst where the updated state is written, after the attention output
    uint32_t sq1, sq2, sq3; // q strides for dims 1..3, in float elements; k shares these strides
    uint32_t sv1, sv2, sv3; // v strides for dims 1..3, in float elements
    uint32_t sb1, sb2, sb3; // beta (and scalar gate) strides for dims 1..3, in float elements
    uint32_t rq1, rq3;      // broadcast ratios v/q along ne[1] (heads) and ne[3] (seqs)
};

struct vk_op_ssm_scan_push_constants {
uint32_t nb02, nb03, nb12, nb13;
uint32_t nb21, nb22, nb31;
Expand Down Expand Up @@ -4559,6 +4576,13 @@ static void ggml_vk_load_shaders(vk_device& device) {

ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1);

ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d32, "gated_delta_net_f32_d32", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {32, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d64, "gated_delta_net_f32_d64", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {64, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d128, "gated_delta_net_f32_d128", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {128, 0}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d32_kda, "gated_delta_net_f32_d32_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {32, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d64_kda, "gated_delta_net_f32_d64_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {64, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_gated_delta_net_f32_d128_kda, "gated_delta_net_f32_d128_kda", gated_delta_net_f32_len, gated_delta_net_f32_data, "main", 7, sizeof(vk_op_gated_delta_net_push_constants), {1, 1, 1}, {128, 1}, 1);

if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d128, "ssm_scan_128_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {128, device->subgroup_size}, 1, true, true);
ggml_vk_create_pipeline(device, device->pipeline_ssm_scan_f32_d256, "ssm_scan_256_f32", ssm_scan_subgroup_f32_len, ssm_scan_subgroup_f32_data, "main", 8, sizeof(vk_op_ssm_scan_push_constants), {1, 1, 1}, {256, device->subgroup_size}, 1, true, true);
Expand Down Expand Up @@ -9478,6 +9502,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_rwkv_wkv7_f32;
}
return nullptr;
case GGML_OP_GATED_DELTA_NET:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t S_v = dst->src[2]->ne[0];
const bool kda = (dst->src[3]->ne[0] == (int64_t)S_v);
if (kda) {
switch (S_v) {
case 32: return ctx->device->pipeline_gated_delta_net_f32_d32_kda;
case 64: return ctx->device->pipeline_gated_delta_net_f32_d64_kda;
case 128: return ctx->device->pipeline_gated_delta_net_f32_d128_kda;
}
} else {
switch (S_v) {
case 32: return ctx->device->pipeline_gated_delta_net_f32_d32;
case 64: return ctx->device->pipeline_gated_delta_net_f32_d64;
case 128: return ctx->device->pipeline_gated_delta_net_f32_d128;
}
}
}
return nullptr;
case GGML_OP_SSM_SCAN:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
const uint32_t d_state = src0->ne[0];
Expand Down Expand Up @@ -10308,6 +10351,67 @@ static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx,
);
}

// Record a dispatch of the fused gated-delta-net kernel.
//
// Sources (dst->src[0..5]): q, k, v, g (gate), beta, state — all non-quantized.
// The shader writes the attention output at the start of dst and the updated
// recurrent state at float offset s_off. One workgroup is launched per
// (head, sequence) pair; the pipeline variant (head size, KDA) is selected by
// ggml_vk_op_get_pipeline from the tensor shapes.
static void ggml_vk_gated_delta_net(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
    const ggml_tensor * src_q    = dst->src[0];
    const ggml_tensor * src_v    = dst->src[2];
    const ggml_tensor * src_beta = dst->src[4];

    GGML_ASSERT(dst->buffer != nullptr);

    const uint32_t S_v      = (uint32_t)src_v->ne[0];
    const uint32_t H        = (uint32_t)src_v->ne[1];
    const uint32_t n_tokens = (uint32_t)src_v->ne[2];
    const uint32_t n_seqs   = (uint32_t)src_v->ne[3];

    // The attention output occupies S_v*H*n_tokens*n_seqs floats; the updated
    // state is stored immediately after it.
    const uint32_t s_off = S_v * H * n_tokens * n_seqs;

    for (int i = 0; i < 6; i++) {
        GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type));
    }

    vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op);
    GGML_ASSERT(pipeline != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);

    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer src_buf[6] = {};
    for (int i = 0; i < 6; i++) {
        src_buf[i] = ggml_vk_tensor_subbuffer(ctx, dst->src[i]);
    }

    // Convert a byte stride to a float-element stride for the shader.
    // NOTE(review): assumes nb[] is a multiple of sizeof(float) — holds for the
    // F32 tensors this op supports.
    auto elt_stride = [](const ggml_tensor * t, int dim) {
        return (uint32_t)(t->nb[dim] / sizeof(float));
    };

    const vk_op_gated_delta_net_push_constants pc = {
        H, n_tokens, n_seqs, s_off,
        // q and k share strides
        elt_stride(src_q,    1), elt_stride(src_q,    2), elt_stride(src_q,    3),
        elt_stride(src_v,    1), elt_stride(src_v,    2), elt_stride(src_v,    3),
        elt_stride(src_beta, 1), elt_stride(src_beta, 2), elt_stride(src_beta, 3),
        // broadcast ratios: v head/seq counts over q head/seq counts
        (uint32_t)(src_v->ne[1] / src_q->ne[1]),
        (uint32_t)(src_v->ne[3] / src_q->ne[3]),
    };

    // One workgroup per (head, sequence).
    std::array<uint32_t, 3> elements = { H, n_seqs, 1 };

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline,
        {src_buf[0], src_buf[1], src_buf[2], src_buf[3], src_buf[4], src_buf[5], dst_buf},
        pc, elements);
}

static void ggml_vk_ssm_scan(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
Expand Down Expand Up @@ -13024,6 +13128,11 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr

break;

case GGML_OP_GATED_DELTA_NET:
ggml_vk_gated_delta_net(ctx, compute_ctx, node);

break;

case GGML_OP_SSM_SCAN:
ggml_vk_ssm_scan(ctx, compute_ctx, node);

Expand Down Expand Up @@ -15426,6 +15535,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true; // all inputs are contiguous, see ggml.c
case GGML_OP_GATED_DELTA_NET:
{
const uint32_t S_v = op->src[2]->ne[0];
if (S_v != 32 && S_v != 64 && S_v != 128) {
return false;
}
for (int i = 0; i < 6; i++) {
if (op->src[i] && ggml_is_quantized(op->src[i]->type)) {
return false;
}
}
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
}
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
Expand Down Expand Up @@ -16299,6 +16421,9 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
} else if (tensor->op == GGML_OP_RWKV_WKV7) {
tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3],
src_clone[4], src_clone[5], src_clone[6]);
} else if (tensor->op == GGML_OP_GATED_DELTA_NET) {
tensor_clone = ggml_gated_delta_net(ggml_ctx, src_clone[0], src_clone[1],
src_clone[2], src_clone[3], src_clone[4], src_clone[5]);
} else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) {
src_clone[0]->flags = tensor->src[0]->flags;
tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1],
Expand Down
122 changes: 122 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/gated_delta_net.comp
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
#version 450

#extension GL_EXT_control_flow_attributes : require

// Head/state dimension; also the workgroup size (one invocation per state column).
layout(constant_id = 0) const uint S_V = 128;
// Non-zero selects the KDA variant: per-row vector gate g[i] instead of a
// single scalar gate per (head, token).
layout(constant_id = 1) const uint KDA = 0;

layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Must match vk_op_gated_delta_net_push_constants on the host side.
layout(push_constant) uniform Parameters {
    uint H;               // number of heads
    uint n_tokens;        // tokens per sequence
    uint n_seqs;          // number of sequences
    uint s_off;           // float offset in dst where the updated state is written
    uint sq1, sq2, sq3;   // q/k strides (floats) for dims 1..3
    uint sv1, sv2, sv3;   // v strides (floats) for dims 1..3
    uint sb1, sb2, sb3;   // beta/gate strides (floats) for dims 1..3
    uint rq1, rq3;        // broadcast ratios: v over q along heads / seqs
};

layout(binding = 0) readonly buffer QBuf { float q_in[]; };
layout(binding = 1) readonly buffer KBuf { float k_in[]; };
layout(binding = 2) readonly buffer VBuf { float v_in[]; };
layout(binding = 3) readonly buffer GBuf { float g_in[]; };
layout(binding = 4) readonly buffer BetaBuf { float beta_in[]; };
layout(binding = 5) readonly buffer StateBuf { float state_in[]; };
layout(binding = 6) buffer DstBuf { float dst[]; };

// Per-token q and k vectors, staged so every invocation can read all S_V lanes.
shared float s_k[S_V];
shared float s_q[S_V];

void main() {
    // Grid: x = head, y = sequence; each invocation owns one state column.
    const uint head_id = gl_WorkGroupID.x;
    const uint seq_id = gl_WorkGroupID.y;
    const uint col = gl_LocalInvocationID.x;

    // Defensive guard: local_size_x is specialized to S_V, so this never
    // triggers and barrier() below stays in uniform control flow.
    if (col >= S_V) {
        return;
    }

    // Map (head, seq) into q/k index space (q may have fewer heads/seqs than v).
    const uint iq1 = head_id / rq1;
    const uint iq3 = seq_id / rq3;

    const uint state_size = S_V * S_V;
    const uint state_base = (seq_id * H + head_id) * state_size;

    const float scale = 1.0 / sqrt(float(S_V));

    // Load state column into registers: S[i][col] for all rows i
    float state[S_V];
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        state[i] = state_in[state_base + i * S_V + col];
    }

    // dst attention output is laid out [S_V, H, n_tokens, n_seqs], contiguous.
    uint attn_off = (seq_id * n_tokens * H + head_id) * S_V;

    // Sequential scan over tokens: the state carries across iterations.
    for (uint t = 0; t < n_tokens; t++) {
        // Load q and k into shared memory
        const uint q_off = iq3 * sq3 + t * sq2 + iq1 * sq1;
        const uint k_off = q_off; // q and k share strides (asserted in ggml.c)
        const uint v_off = seq_id * sv3 + t * sv2 + head_id * sv1;

        s_q[col] = q_in[q_off + col];
        s_k[col] = k_in[k_off + col];
        barrier(); // make this token's q/k visible to all lanes before reading

        const float v_val = v_in[v_off + col];

        // beta and gate offsets
        const uint gb_off = seq_id * sb3 + t * sb2 + head_id * sb1;
        const float beta_val = beta_in[gb_off];

        if (KDA == 0) {
            // Scalar gate: one exp(g) decay factor for the whole head.
            const float g_val = exp(g_in[gb_off]);

            // kv_col = (S^T @ k)[col] = sum_i S[i][col] * k[i]
            float kv_col = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                kv_col += state[i] * s_k[i];
            }

            // delta_col = (v[col] - g * kv_col) * beta
            float delta_col = (v_val - g_val * kv_col) * beta_val;

            // Fused decay + rank-1 update + readout
            float attn_col = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                state[i] = g_val * state[i] + s_k[i] * delta_col;
                attn_col += state[i] * s_q[i];
            }

            dst[attn_off + col] = attn_col * scale;
        } else {
            // KDA: per-row vector gate
            const uint g_base = gb_off * S_V;

            // kv_col = sum_i exp(g[i]) * S[i][col] * k[i]
            float kv_col = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                kv_col += exp(g_in[g_base + i]) * state[i] * s_k[i];
            }

            float delta_col = (v_val - kv_col) * beta_val;

            // Per-row decay with the same exp(g[i]) factors, then readout.
            float attn_col = 0.0;
            [[unroll]] for (uint i = 0; i < S_V; i++) {
                state[i] = exp(g_in[g_base + i]) * state[i] + s_k[i] * delta_col;
                attn_col += state[i] * s_q[i];
            }

            dst[attn_off + col] = attn_col * scale;
        }

        // Advance to the next token's output slot (token stride is S_V * H).
        attn_off += S_V * H;
        barrier(); // don't overwrite s_q/s_k until every lane finished this token
    }

    // Write final state
    [[unroll]] for (uint i = 0; i < S_V; i++) {
        dst[s_off + state_base + i * S_V + col] = state[i];
    }
}
2 changes: 2 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -987,6 +987,8 @@ void process_shaders() {

string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("gated_delta_net_f32", "gated_delta_net.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));

Expand Down