@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
418418 {
419419 n_fuse = ggml_metal_op_opt_step_adamw (ctx, idx);
420420 } break ;
421+ case GGML_OP_OPT_STEP_SGD:
422+ {
423+ n_fuse = ggml_metal_op_opt_step_sgd (ctx, idx);
424+ } break ;
421425 default :
422426 {
423427 GGML_LOG_ERROR (" %s: error: node %3d, op = %8s not implemented\n " , __func__, idx, ggml_op_name (node->op ));
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
34693473
34703474 return 1 ;
34713475}
3476+
3477+ int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx) { // Encodes graph node `idx` (the GGML_OP_OPT_STEP_SGD case of the op dispatcher) onto ctx's encoder.
3478+ ggml_tensor * op = ctx->node (idx);
3479+ // Returns 1; the dispatcher stores this value in n_fuse.
3480+ ggml_metal_library_t lib = ctx->lib ;
3481+ ggml_metal_encoder_t enc = ctx->enc ;
3482+ // The macros below presumably declare shape/stride locals from op->src[0] and op - confirm against the macro definition.
3483+ GGML_TENSOR_LOCALS ( int32_t , ne0, op->src [0 ], ne);
3484+ GGML_TENSOR_LOCALS (uint64_t , nb0, op->src [0 ], nb);
3485+ GGML_TENSOR_LOCALS ( int32_t , ne, op, ne); // presumably the source of the `ne0` local used for nth below - TODO confirm
3486+ GGML_TENSOR_LOCALS (uint32_t , nb, op, nb);
3487+ // Apart from ne0, none of the macro-declared locals are referenced by name in this function.
3488+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd (lib, op);
3489+
3490+ const int64_t np = ggml_nelements (op->src [0 ]); // element count of src[0]; sizes the dispatch below
3491+ ggml_metal_kargs_opt_step_sgd args = {
3492+ /* .np =*/ np,
3493+ };
3494+
3495+ int ida = 0 ; // running argument index: args take 0, then src[0], src[1], src[2] take 1, 2, 3
3496+
3497+ ggml_metal_encoder_set_pipeline (enc, pipeline);
3498+ ggml_metal_encoder_set_bytes (enc, &args, sizeof (args), ida++);
3499+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [0 ]), ida++);
3500+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [1 ]), ida++);
3501+ ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id (op->src [2 ]), ida++);
3502+ // NOTE(review): nth is 0 when ne0 is 0, making the division below divide by zero - confirm empty tensors are filtered before this call.
3503+ const int nth = std::min (ggml_metal_pipeline_max_theads_per_threadgroup (pipeline), ne0); // capped by the pipeline limit and by ne0
3504+ const int64_t n = (np + nth - 1 ) / nth; // ceiling of np / nth, so n * nth >= np
3505+
3506+ ggml_metal_encoder_dispatch_threadgroups (enc, n, 1 , 1 , nth, 1 , 1 ); // presumably n x 1 x 1 threadgroups of nth x 1 x 1 threads - confirm argument order
3507+
3508+ return 1 ;
3509+ }
0 commit comments