Skip to content

Commit a31cf36

Browse files
authored
metal : add opt_step_adamw and op_sum (ggml-org#16529)
* scaffold to support opt step adamw on metal (not written so far) * add opt-step-adamw kernel for metal * pass op->src[4] as a separate buffer to the pipeline * add bounds check to opt-step-adamw kernel * complete scaffold for GGML_OP_SUM * naive GGML_OP_SUM kernel * remove unwanted comment * change OP_SUM capability gate * Add has_simdgroup_reduction to both ops to pass CI
1 parent 81d54bb commit a31cf36

File tree

7 files changed

+172
-0
lines changed

7 files changed

+172
-0
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l
268268
return res;
269269
}
270270

271+
// Return the pipeline used for GGML_OP_SUM on the element type of op->src[0].
//
// The kernel function is named "kernel_op_sum_<type>", with <type> taken from
// ggml_type_name(). The pipeline name is a plain copy of that function name.
// An existing pipeline is looked up by name first; when the lookup yields
// nothing, the pipeline is compiled and that result is returned instead.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_SUM);

    char base[256];
    char name[256];

    const char * type_name = ggml_type_name(op->src[0]->type);

    snprintf(base, sizeof(base), "kernel_op_sum_%s", type_name);
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        // nothing found under this name yet -> compile it now
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}
289+
271290
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
272291
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
273292

@@ -1482,3 +1501,21 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
14821501
return res;
14831502
}
14841503

1504+
// Return the pipeline used for GGML_OP_OPT_STEP_ADAMW on the element type of
// op->src[0].
//
// The kernel function is named "kernel_opt_step_adamw_<type>", with <type> taken
// from ggml_type_name(). The pipeline name is a plain copy of that function name.
// An existing pipeline is looked up by name first; when the lookup yields
// nothing, the pipeline is compiled and that result is returned instead.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
    assert(op->op == GGML_OP_OPT_STEP_ADAMW);

    char base[256];
    char name[256];

    const char * type_name = ggml_type_name(op->src[0]->type);

    snprintf(base, sizeof(base), "kernel_opt_step_adamw_%s", type_name);
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (!res) {
        // nothing found under this name yet -> compile it now
        res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return res;
}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me
109109
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
110110
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
111111
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
112+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op);
112113
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op);
113114
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op);
114115
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op);
@@ -134,6 +135,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
134135
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
135136
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
136137
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
138+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
137139

138140
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
139141
ggml_metal_library_t lib,

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
656656
case GGML_OP_COS:
657657
case GGML_OP_LOG:
658658
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
659+
case GGML_OP_SUM:
659660
case GGML_OP_SUM_ROWS:
660661
case GGML_OP_MEAN:
661662
case GGML_OP_SOFT_MAX:
@@ -798,6 +799,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
798799
return false;
799800
};
800801
}
802+
case GGML_OP_OPT_STEP_ADAMW:
803+
return has_simdgroup_reduction;
801804
default:
802805
return false;
803806
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,10 @@ typedef struct{
544544
float limit;
545545
} ggml_metal_kargs_glu;
546546

547+
// arguments of the kernel_op_sum_* kernels
typedef struct {
    uint64_t np; // number of source elements the kernel adds up
} ggml_metal_kargs_sum;
550+
547551
typedef struct {
548552
int64_t ne00;
549553
int64_t ne01;
@@ -773,4 +777,8 @@ typedef struct {
773777
uint64_t nb01;
774778
} ggml_metal_kargs_argmax;
775779

780+
// arguments of the kernel_opt_step_adamw_* kernels
typedef struct {
    int64_t np; // number of elements to update; threads with an index >= np exit early
} ggml_metal_kargs_opt_step_adamw;
783+
776784
#endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
301301
{
302302
n_fuse = ggml_metal_op_glu(ctx, idx);
303303
} break;
304+
case GGML_OP_SUM:
305+
{
306+
n_fuse = ggml_metal_op_sum(ctx, idx);
307+
} break;
304308
case GGML_OP_SUM_ROWS:
305309
case GGML_OP_MEAN:
306310
{
@@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
410414
{
411415
n_fuse = ggml_metal_op_argmax(ctx, idx);
412416
} break;
417+
case GGML_OP_OPT_STEP_ADAMW:
418+
{
419+
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
420+
} break;
413421
default:
414422
{
415423
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
840848
return 1;
841849
}
842850

851+
// Encode GGML_OP_SUM for graph node `idx`.
//
// Argument layout seen by the kernel:
//   0 - ggml_metal_kargs_sum (np = element count of op->src[0])
//   1 - op->src[0]
//   2 - op (destination)
//
// The work is dispatched as one threadgroup holding one thread.
// Returns 1, which the encoder switch stores in n_fuse.
int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    ggml_metal_kargs_sum args = {
        /*.np =*/ (uint64_t) ggml_nelements(op->src[0]),
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op);

    int ida = 0;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                   ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op),         ida++);

    // single threadgroup, single thread
    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);

    return 1;
}
874+
843875
int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
844876
ggml_tensor * op = ctx->node(idx);
845877

@@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
34013433

34023434
return 1;
34033435
}
3436+
3437+
// Encode GGML_OP_OPT_STEP_ADAMW for graph node `idx`.
//
// Argument layout seen by the kernel:
//   0 - ggml_metal_kargs_opt_step_adamw (np = element count of op->src[0])
//   1 - op->src[0]
//   2 - op->src[1]
//   3 - op->src[2]
//   4 - op->src[3]
//   5 - op->src[4]
//
// One thread is launched per element: ceil(np / nth) threadgroups of nth threads,
// where nth is the smaller of the pipeline's thread limit and op->ne[0].
// Returns 1, which the encoder switch stores in n_fuse.
int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op);

    const int64_t np = ggml_nelements(op->src[0]);

    ggml_metal_kargs_opt_step_adamw args = {
        /*.np =*/ np,
    };

    int ida = 0;

    ggml_metal_encoder_set_pipeline(enc, pipeline);
    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args),                   ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[3]), ida++);
    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[4]), ida++);

    // Pick the threadgroup size in 64-bit so that a large op->ne[0] is not narrowed
    // before the comparison, and clamp it to at least 1 so the ceil-division below
    // can never divide by zero.
    const int64_t nth_max = ggml_metal_pipeline_max_theads_per_threadgroup(pipeline);
    const int     nth     = (int) std::max<int64_t>(1, std::min<int64_t>(nth_max, op->ne[0]));

    // number of threadgroups needed to cover all np elements
    const int64_t n = (np + nth - 1) / nth;

    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);

    return 1;
}

ggml/src/ggml-metal/ggml-metal-ops.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
5050
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
5151
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
5252
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
53+
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
5354
int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx);
5455
int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx);
5556
int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx);
@@ -78,6 +79,7 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
7879
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
7980
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
8081
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
82+
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
8183

8284
#ifdef __cplusplus
8385
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32(
17231723
}
17241724
}
17251725

1726+
// Adds up the first args.np floats of src0 and writes the total to dst[0].
//
// The reduction is serial: only the thread whose index in its threadgroup is 0
// does any work, every other thread falls through without touching memory.
kernel void kernel_op_sum_f32(
        constant ggml_metal_kargs_sum & args,
        device const float * src0,
        device       float * dst,
        ushort tiitg[[thread_index_in_threadgroup]]) {

    if (tiitg == 0) {
        float total = 0.0f;

        // walk the input front to back, accumulating in element order
        ulong i = 0;
        while (i < args.np) {
            total += src0[i];
            ++i;
        }

        dst[0] = total;
    }
}
1743+
17261744
template <bool norm>
17271745
kernel void kernel_sum_rows(
17281746
constant ggml_metal_kargs_sum_rows & args,
@@ -8754,3 +8772,37 @@ kernel void kernel_pool_2d_avg_f32(
87548772

87558773
o_ptr[cur_oh * args.OW + cur_ow] = res;
87568774
}
8775+
8776+
// Per-element optimizer step: each thread updates the element at its grid position.
//
//   x    - values being updated (read and written)
//   g    - gradients (read only)
//   g_m  - first-moment running average (read and written)
//   g_v  - second-moment running average (read and written)
//   pars - 7 floats: alpha, beta1, beta2, eps, wd, beta1h, beta2h
//          (beta1h / beta2h are presumably bias-correction factors precomputed
//           by the caller - TODO confirm against the code that fills pars)
//
// Threads whose grid position is >= args.np return without touching memory.
kernel void kernel_opt_step_adamw_f32(
        constant    ggml_metal_kargs_opt_step_adamw & args,
        device       float * x,
        device const float * g,
        device       float * g_m,
        device       float * g_v,
        device const float * pars,
        uint        gid[[thread_position_in_grid]]) {

    // the grid may be larger than the data
    if (gid >= args.np) {
        return;
    }

    const float alpha  = pars[0];
    const float beta1  = pars[1];
    const float beta2  = pars[2];
    const float eps    = pars[3];
    const float wd     = pars[4];
    const float beta1h = pars[5];
    const float beta2h = pars[6];

    const float grad = g[gid];

    // blend the stored moments with the current gradient and write them back
    const float m_new = g_m[gid] * beta1 + grad * (1.0f - beta1);
    const float v_new = g_v[gid] * beta2 + grad * grad * (1.0f - beta2);

    g_m[gid] = m_new;
    g_v[gid] = v_new;

    // scaled moments used by the update
    const float m_hat = m_new * beta1h;
    const float denom = sqrt(v_new * beta2h) + eps;

    // weight decay on the current value, then the moment-based step
    const float decay = 1.0f - alpha * wd;
    const float step  = alpha * m_hat / denom;

    x[gid] = x[gid] * decay - step;
}

0 commit comments

Comments
 (0)