@@ -301,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
301301 {
302302 n_fuse = ggml_metal_op_glu (ctx, idx);
303303 } break ;
304+ case GGML_OP_SUM:
305+ {
306+ n_fuse = ggml_metal_op_sum (ctx, idx);
307+ } break ;
304308 case GGML_OP_SUM_ROWS:
305309 case GGML_OP_MEAN:
306310 {
@@ -410,6 +414,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
410414 {
411415 n_fuse = ggml_metal_op_argmax (ctx, idx);
412416 } break ;
417+ case GGML_OP_OPT_STEP_ADAMW:
418+ {
419+ n_fuse = ggml_metal_op_opt_step_adamw (ctx, idx);
420+ } break ;
413421 default :
414422 {
415423 GGML_LOG_ERROR (" %s: error: node %3d, op = %8s not implemented\n " , __func__, idx, ggml_op_name (node->op ));
@@ -840,6 +848,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) {
840848 return 1 ;
841849}
842850
851+ int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx) {
852+ ggml_tensor * op = ctx->node (idx);
853+
854+ ggml_metal_library_t lib = ctx->lib ;
855+ ggml_metal_encoder_t enc = ctx->enc ;
856+
857+ const uint64_t n = (uint64_t ) ggml_nelements (op->src [0 ]);
858+
859+ ggml_metal_kargs_sum args = {
860+ /* .np =*/ n,
861+ };
862+
863+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum (lib, op);
864+
865+ ggml_metal_encoder_set_pipeline (enc, pipeline);
866+ ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), 0 );
867+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), 1 );
868+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op), 2 );
869+
870+ ggml_metal_encoder_dispatch_threadgroups (enc, 1 , 1 , 1 , 1 , 1 , 1 );
871+
872+ return 1 ;
873+ }
874+
843875int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx) {
844876 ggml_tensor * op = ctx->node (idx);
845877
@@ -3401,3 +3433,39 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
34013433
34023434 return 1 ;
34033435}
3436+
3437+ int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx) {
3438+ ggml_tensor * op = ctx->node (idx);
3439+
3440+ ggml_metal_library_t lib = ctx->lib ;
3441+ ggml_metal_encoder_t enc = ctx->enc ;
3442+
3443+ GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
3444+ GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
3445+ GGML_TENSOR_LOCALS ( int32_t , ne, op, ne);
3446+ GGML_TENSOR_LOCALS (uint32_t , nb, op, nb);
3447+
3448+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw (lib, op);
3449+
3450+ const int64_t np = ggml_nelements (op->src [0 ]);
3451+ ggml_metal_kargs_opt_step_adamw args = {
3452+ /* .np =*/ np,
3453+ };
3454+
3455+ int ida = 0 ;
3456+
3457+ ggml_metal_encoder_set_pipeline (enc, pipeline);
3458+ ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), ida++);
3459+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), ida++);
3460+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), ida++);
3461+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), ida++);
3462+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [3 ]), ida++);
3463+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [4 ]), ida++);
3464+
3465+ const int nth = std::min (ggml_metal_pipeline_max_theads_per_threadgroup (pipeline), ne0);
3466+ const int64_t n = (np + nth - 1 ) / nth;
3467+
3468+ ggml_metal_encoder_dispatch_threadgroups (enc, n, 1 , 1 , nth, 1 , 1 );
3469+
3470+ return 1 ;
3471+ }
0 commit comments