Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions ggml/src/ggml-cpu/ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(
const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1);

const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);
const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1);

if (kda) {
for (int64_t i = 0; i < S_v; ++i) {
Expand Down Expand Up @@ -10501,7 +10501,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk(

attn_data += S_v * H; // advance to next token
}

}
}

Expand Down
35 changes: 35 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_
return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];

// v is src[2], dimensions: S_v = ne[0], H = ne[1]
const int ne20 = op->src[2]->ne[0]; // S_v
const int ne21 = op->src[2]->ne[1]; // H
const int ne30 = op->src[3]->ne[0]; // G

const int nsg = op->src[2]->ne[0]/32;

GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32);
GGML_ASSERT(op->ne[0] == ne20 * ne21);
GGML_ASSERT(ne20 % 32 == 0);

snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg);
snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30);

ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();

ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0);
ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1);

res = ggml_metal_library_compile_pipeline(lib, base, name, cv);

ggml_metal_cv_free(cv);
}

res.nsg = nsg;

return res;
}

ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) {
char base[256];
char name[256];
Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op);
Expand Down
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_GATED_DELTA_NET:
return op->src[2]->ne[0] % 32 == 0;
case GGML_OP_SOLVE_TRI:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
Expand Down
39 changes: 39 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
#define FC_UPSCALE 1500
#define FC_GATED_DELTA_NET 1600

// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
Expand Down Expand Up @@ -793,6 +794,44 @@ typedef struct {
uint64_t nb0;
} ggml_metal_kargs_ssm_scan;

typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
int32_t ne13;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne20;
int32_t ne21;
int32_t ne22;
int32_t ne23;
uint64_t nb20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ns02;
int32_t ns12;
int32_t ns22;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
} ggml_metal_kargs_gated_delta_net;

typedef struct {
int32_t ne00;
int32_t ne01;
Expand Down
79 changes: 79 additions & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_rwkv(ctx, idx);
} break;
case GGML_OP_GATED_DELTA_NET:
{
n_fuse = ggml_metal_op_gated_delta_net(ctx, idx);
} break;
case GGML_OP_SOLVE_TRI:
{
n_fuse = ggml_metal_op_solve_tri(ctx, idx);
Expand Down Expand Up @@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) {
return 1;
}

int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;


GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne);
GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);

auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op);

int ida = 0;

ggml_metal_kargs_gated_delta_net args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.ne13 =*/ ne13,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne20 =*/ ne20,
/*.ne21 =*/ ne21,
/*.ne22 =*/ ne22,
/*.ne23 =*/ ne23,
/*.nb20 =*/ nb20,
/*.nb21 =*/ nb21,
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ns02 =*/ (int32_t) (nb02/sizeof(float)),
/*.ns12 =*/ (int32_t) (nb12/sizeof(float)),
/*.ns22 =*/ (int32_t) (nb22/sizeof(float)),
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};

ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst

const int nsg = pipeline.nsg;

ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1);

return 1;
}

int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down
1 change: 1 addition & 0 deletions ggml/src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
Expand Down
117 changes: 117 additions & 0 deletions ggml/src/ggml-metal/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -2434,6 +2434,123 @@ kernel void kernel_rwkv_wkv7_f32(
}
}

constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]];
constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]];

template<short NSG>
kernel void kernel_gated_delta_net_impl(
constant ggml_metal_kargs_gated_delta_net & args,
device const char * q,
device const char * k,
device const char * v,
device const char * g,
device const char * b,
device const char * s,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
#define S_v FC_gated_delta_net_ne20
#define G FC_gated_delta_net_ne30

const uint tx = tpitg.x;
const uint ty = tpitg.y;

const uint i23 = tgpig.z; // B
const uint i21 = tgpig.y; // H
const uint i20 = tgpig.x*NSG + ty;

const uint i01 = i21 % args.ne01;
const uint i11 = i21 % args.ne11;

const float scale = 1.0f / sqrt((float)S_v);

device const float * s_ptr = (device const float *) (s) + (i23 * args.ne21*S_v*S_v + i21*S_v*S_v);

float ls[NSG];

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] = s_ptr[is*S_v + i20];
}

device float * dst_attn = (device float *) (dst) + i23*args.ne22*args.ne21*S_v + i21*S_v + i20;

device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01);
device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11);
device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21);

device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21);
device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G;

FOR_UNROLL (short t = 0; t < args.ne22; t++) {
const float beta = b_ptr[0];

float s_k = 0.0f;

if (G == 1) {
const float g_exp = exp(g_ptr[0]);

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= g_exp;

s_k += ls[j]*k_ptr[is];
}
} else {
// KDA
FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] *= exp(g_ptr[is]);

s_k += ls[j]*k_ptr[is];
}
}

s_k = simd_sum(s_k);

const float d = (v_ptr[i20] - s_k)*beta;

float y = 0.0f;

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
ls[j] += k_ptr[is]*d;

y += ls[j]*q_ptr[is];
}

y = simd_sum(y);

if (tx == 0) {
dst_attn[t*args.ne21*S_v] = y*scale;
}

q_ptr += args.ns02;
k_ptr += args.ns12;
v_ptr += args.ns22;

b_ptr += args.ne21;
g_ptr += args.ne21*G;
}

device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20;

FOR_UNROLL (short j = 0; j < NSG; j++) {
const short is = tx*NSG + j;
dst_state[is*S_v] = ls[j];
}

#undef S_v
#undef G
}

typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t;

template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>;
template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>;
template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>;

constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]];
constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]];
constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];
Expand Down
Loading