Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);

const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);

if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
Expand Down Expand Up @@ -10501,7 +10501,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(

attn_data += S_v * H; // advance to next token
}

}
}

Expand Down
35 changes: 35 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];

// v is src[2], dimensions: S_v = ne[0], H = ne[1]
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G

const int nsg = op->src[2]->ne[0]/32;

GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(op->ne[0] == ne20 * ne21);
GGML_ASSERT(ne20 % 32 == 0);

snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);
}

res.nsg = nsg;

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
return op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
Expand Down
39 changes: 39 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
#define FC_GATED_DELTA_NET 1600

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
Expand Down Expand Up @@ -793,6 +794,44 @@ typedef struct {
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne22;
int32_t ne23;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ns02;
int32_t ns12;
int32_t ns22;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_gated_delta_net;

typedef struct {
int32_t ne00;
int32_t ne01;
Expand Down
79 changes: 79 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_rwkv(ctx, idx);
} break;
case GGML_OP_GATED_DELTA_NET:
{
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
Expand Down Expand Up @@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;


GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);

int ida = 0;

ggml_metal_kargs_gated_delta_net args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20,
/*.ne21 =*/ ne21,
/*.ne22 =*/ ne22,
/*.ne23 =*/ ne23,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst

const int nsg = pipeline.nsg;

ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);

return 1;
}

int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
Expand Down
Loading
Loading