diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index d324128c893..ffa272b8da7 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -10468,7 +10468,7 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( const float * v_d = (const float *)((const char *)src_v->data + iv3 * nbv3 + t * nbv2 + iv1 * nbv1); const float beta_val = *(const float *)((const char *)src_beta->data + iv3 * nbb3 + t * nbb2 + iv1 * nbb1); - const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); + const float * g_d = (const float *)((const char *)src_g->data + iv3 * nbg3 + t * nbg2 + iv1 * nbg1); if (kda) { for (int64_t i = 0; i < S_v; ++i) { @@ -10501,7 +10501,6 @@ static void ggml_compute_forward_gated_delta_net_one_chunk( attn_data += S_v * H; // advance to next token } - } } diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 169c63dd7a4..15ae2e517df 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -577,6 +577,41 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv(ggml_metal_ return res; } +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net(ggml_metal_library_t lib, const ggml_tensor * op) { + char base[256]; + char name[256]; + + // v is src[2], dimensions: S_v = ne[0], H = ne[1] + const int ne20 = op->src[2]->ne[0]; // S_v + const int ne21 = op->src[2]->ne[1]; // H + const int ne30 = op->src[3]->ne[0]; // G + + const int nsg = op->src[2]->ne[0]/32; + + GGML_ASSERT(op->src[5]->type == GGML_TYPE_F32); + GGML_ASSERT(op->ne[0] == ne20 * ne21); + GGML_ASSERT(ne20 % 32 == 0); + + snprintf(base, 256, "kernel_gated_delta_net_%s_%d", ggml_type_name(op->src[0]->type), nsg); + snprintf(name, 256, "%s_ne20=%d_ne30=%d", base, ne20, ne30); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, ne20, FC_GATED_DELTA_NET + 0); + ggml_metal_cv_set_int16(cv, ne30, FC_GATED_DELTA_NET + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.nsg = nsg; + + return res; +} + ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri(ggml_metal_library_t lib, const ggml_tensor * op) { char base[256]; char name[256]; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 93d7f6a216f..fd2b3ddeb55 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -125,6 +125,7 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_conv_batched (ggml_metal_library_t lib, const struct ggml_tensor * op, int ssm_conv_bs); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_ssm_scan (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_rwkv (ggml_metal_library_t lib, const struct ggml_tensor * op); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_gated_delta_net (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_solve_tri (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mv_ext (ggml_metal_library_t lib, enum ggml_type tsrc0, enum ggml_type tsrc1, int nsg, int nxpsg, int r1ptg); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_mul_mm (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 23bd2b2ab72..a4b176841ce 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1155,6 +1155,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV7: return true; + case GGML_OP_GATED_DELTA_NET: + return op->src[2]->ne[0] % 32 == 0; case GGML_OP_SOLVE_TRI: case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index bf51055e367..82dc2c728a1 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -84,6 +84,7 @@ #define FC_BIN 1300 #define FC_SUM_ROWS 1400 #define FC_UPSCALE 1500 +#define FC_GATED_DELTA_NET 1600 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -793,6 +794,44 @@ typedef struct { uint64_t nb0; } ggml_metal_kargs_ssm_scan; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne20; + int32_t ne21; + int32_t ne22; + int32_t ne23; + uint64_t nb20; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ns02; + int32_t ns12; + int32_t ns22; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_gated_delta_net; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 267755d08cc..306dbcf3660 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -333,6 +333,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_rwkv(ctx, idx); } break; + case GGML_OP_GATED_DELTA_NET: + { + n_fuse = ggml_metal_op_gated_delta_net(ctx, idx); + } break; case GGML_OP_SOLVE_TRI: { n_fuse = ggml_metal_op_solve_tri(ctx, idx); @@ -1562,6 +1566,81 @@ int ggml_metal_op_rwkv(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_gated_delta_net(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + auto pipeline = ggml_metal_library_get_pipeline_gated_delta_net(lib, op); + + int ida = 0; + + ggml_metal_kargs_gated_delta_net args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne20 =*/ ne20, + /*.ne21 =*/ ne21, + /*.ne22 =*/ ne22, + /*.ne23 =*/ ne23, + /*.nb20 =*/ nb20, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb23 =*/ nb23, + /*.ns02 =*/ (int32_t) (nb02/sizeof(float)), + /*.ns12 =*/ (int32_t) (nb12/sizeof(float)), + /*.ns22 =*/ (int32_t) (nb22/sizeof(float)), + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); // q + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); // k + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); // v + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); // gate + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); // beta + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[5]), ida++); // state + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), ida++); // dst + + const int nsg = pipeline.nsg; + + ggml_metal_encoder_dispatch_threadgroups(enc, op->src[2]->ne[0]/nsg, op->src[2]->ne[1], op->src[2]->ne[3], 32, nsg, 1); + + return 1; +} + int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index f3e38c7aa9d..019f2fec9ed 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -58,6 +58,7 @@ int ggml_metal_op_soft_max (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_gated_delta_net (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 82ebbb4e409..6d2f41909ae 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -2434,6 +2434,227 @@ kernel void kernel_rwkv_wkv7_f32( } } +constant short FC_gated_delta_net_ne20 [[function_constant(FC_GATED_DELTA_NET + 0)]]; +constant short FC_gated_delta_net_ne30 [[function_constant(FC_GATED_DELTA_NET + 1)]]; + +#if 1 +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float ls[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] = s_ptr[is*S_v]; + } + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + float s_k = 0.0f; + + if (G == 1) { + const float g_exp = exp(g_ptr[0]); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= g_exp; + + s_k += ls[j]*k_ptr[is]; + } + } else { + // KDA + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] *= exp(g_ptr[is]); + + s_k += ls[j]*k_ptr[is]; + } + } + + s_k = simd_sum(s_k); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + float y = 0.0f; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + ls[j] += k_ptr[is]*d; + + y += ls[j]*q_ptr[is]; + } + + y = simd_sum(y); + + if (tx == 0) { + dst_attn[t*args.ne21*S_v] = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = ls[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl<4>) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<1>; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<2>; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl<4>; + +#else +// a simplified version of the above +// no performance improvement, so keep the above version for now + +template +kernel void kernel_gated_delta_net_impl( + constant ggml_metal_kargs_gated_delta_net & args, + device const char * q, + device const char * k, + device const char * v, + device const char * g, + device const char * b, + device const char * s, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { +#define S_v FC_gated_delta_net_ne20 +#define G FC_gated_delta_net_ne30 + + const uint tx = tpitg.x; + const uint ty = tpitg.y; + + const uint i23 = tgpig.z; // B + const uint i21 = tgpig.y; // H + const uint i20 = tgpig.x*NSG + ty; + + const uint i01 = i21 % args.ne01; + const uint i11 = i21 % args.ne11; + + const float scale = 1.0f / sqrt((float)S_v); + + device const float * s_ptr = (device const float *) (s) + (i23*args.ne21 + i21)*S_v*S_v + i20; + + float lsf[NSG]; + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + lsf[j] = s_ptr[is*S_v]; + } + + thread T * ls = (thread T *) (lsf); + + device float * dst_attn = (device float *) (dst) + (i23*args.ne22*args.ne21 + i21)*S_v + i20; + + device const float * q_ptr = (device const float *) (q + i23*args.nb03 + i01*args.nb01); + device const float * k_ptr = (device const float *) (k + i23*args.nb13 + i11*args.nb11); + device const float * v_ptr = (device const float *) (v + i23*args.nb23 + i21*args.nb21); + + device const float * b_ptr = (device const float *) (b) + (i23*args.ne22*args.ne21 + i21); + device const float * g_ptr = (device const float *) (g) + (i23*args.ne22*args.ne21 + i21)*G; + + for (short t = 0; t < args.ne22; t++) { + device const T * qt_ptr = (device const T *) (q_ptr); + device const T * kt_ptr = (device const T *) (k_ptr); + device const T * gt_ptr = (device const T *) (g_ptr); + + if (G == 1) { + *ls *= exp(g_ptr[0]); + } else { + // KDA + *ls *= exp(gt_ptr[tx]); + } + + const float s_k = simd_sum(dot(*ls, kt_ptr[tx])); + + const float d = (v_ptr[i20] - s_k)*b_ptr[0]; + + *ls += kt_ptr[tx]*d; + + const float y = simd_sum(dot(*ls, qt_ptr[tx])); + + if (tx == 0) { + *dst_attn = y*scale; + } + + q_ptr += args.ns02; + k_ptr += args.ns12; + v_ptr += args.ns22; + + b_ptr += args.ne21; + g_ptr += args.ne21*G; + + dst_attn += args.ne21*S_v; + } + + device float * dst_state = (device float *) (dst) + args.ne23*args.ne22*args.ne21*S_v + (i23*args.ne21 + i21)*S_v*S_v + i20; + device T * dstt_state = (device T *) (dst_state); + + FOR_UNROLL (short j = 0; j < NSG; j++) { + const short is = tx*NSG + j; + dst_state[is*S_v] = lsf[j]; + } + +#undef S_v +#undef G +} + +typedef decltype(kernel_gated_delta_net_impl) kernel_gated_delta_net_t; + +template [[host_name("kernel_gated_delta_net_f32_1")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_2")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +template [[host_name("kernel_gated_delta_net_f32_4")]] kernel kernel_gated_delta_net_t kernel_gated_delta_net_impl; +#endif + constant short FC_solve_tri_nsg [[function_constant(FC_SOLVE_TRI + 0)]]; constant short FC_solve_tri_n [[function_constant(FC_SOLVE_TRI + 1)]]; constant short FC_solve_tri_k [[function_constant(FC_SOLVE_TRI + 2)]];