@@ -202,6 +202,15 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
202202    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
203203    GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
204204    GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
205+     GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
206+     GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
207+     GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
208+     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
209+     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
210+     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
211+     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
212+     GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
213+     GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
205214    GGML_METAL_KERNEL_TYPE_RMS_NORM,
206215    GGML_METAL_KERNEL_TYPE_L2_NORM,
207216    GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ @implementation GGMLMetalClass
11691178        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,                 get_rows_iq4_nl,                 true );
11701179        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,                 get_rows_iq4_xs,                 true );
11711180        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                    get_rows_i32,                    true );
1181+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,                    set_rows_f32,                    true );
1182+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,                    set_rows_f16,                    true );
1183+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,                   set_rows_bf16,                   use_bfloat);
1184+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,                   set_rows_q8_0,                   true );
1185+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,                   set_rows_q4_0,                   true );
1186+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,                   set_rows_q4_1,                   true );
1187+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,                   set_rows_q5_0,                   true );
1188+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true );
1189+         GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true );
11721190        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
11731191        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
11741192        GGML_METAL_ADD_KERNEL (GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
16351653    const  bool  use_bfloat              = ctx_dev->use_bfloat ;
16361654
16371655    if  (!use_bfloat) {
1656+         if  (op->type  == GGML_TYPE_BF16) {
1657+             return  false ;
1658+         }
1659+ 
16381660        for  (size_t  i = 0 , n = 3 ; i < n; ++i) {
16391661            if  (op->src [i] != NULL  && op->src [i]->type  == GGML_TYPE_BF16) {
16401662                return  false ;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
18041826            {
18051827                return  op->ne [3 ] == 1 ;
18061828            }
1829+         case  GGML_OP_SET_ROWS: // supported only for an F32 src[0] combined with one of the result types listed below
1830+             {
1831+                 if  (op->src [0 ]->type  != GGML_TYPE_F32) { // any non-F32 source tensor is rejected up front
1832+                     return  false ;
1833+                 }
1834+ 
1835+                 switch  (op->type ) { // accepted result types: the same nine types that have SET_ROWS_* kernel entries in this file
1836+                     case  GGML_TYPE_F32:
1837+                     case  GGML_TYPE_F16:
1838+                     case  GGML_TYPE_BF16: // when use_bfloat is false, the !use_bfloat guard in ggml_metal_supports_op has already returned false for a BF16 op->type
1839+                     case  GGML_TYPE_Q8_0:
1840+                     case  GGML_TYPE_Q4_0:
1841+                     case  GGML_TYPE_Q4_1:
1842+                     case  GGML_TYPE_Q5_0:
1843+                     case  GGML_TYPE_Q5_1:
1844+                     case  GGML_TYPE_IQ4_NL:
1845+                         return  true ;
1846+                     default :
1847+                         return  false ;
1848+                 }; // every path above returns, so control never falls through to the outer default label that follows
1849+             }
18071850        default :
18081851            return  false ;
18091852    }
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
37773820                };
37783821
37793822                [encoder setComputePipelineState: pipeline];
3780-                 [encoder setBuffer: id_src0       offset: offs_src0  atIndex: 0 ];
3781-                 [encoder setBuffer: id_src1       offset: offs_src1  atIndex: 1 ];
3782-                 [encoder setBuffer: id_dst       offset: offs_dst   atIndex: 2 ];
3783-                 [encoder setBytes: &args  length: sizeof (args)  atIndex: 3 ];
3823+                 [encoder setBytes: &args     length: sizeof (args)  atIndex: 0 ];
3824+                 [encoder setBuffer: id_src0  offset: offs_src0     atIndex: 1 ];
3825+                 [encoder setBuffer: id_src1  offset: offs_src1     atIndex: 2 ];
3826+                 [encoder setBuffer: id_dst   offset: offs_dst      atIndex: 3 ];
37843827
37853828                [encoder dispatchThreadgroups: MTLSizeMake (ne10, ne11, 1 ) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
37863829            } break ;
3830+         case  GGML_OP_SET_ROWS: // encode one SET_ROWS dispatch: pick the kernel by dst type, size the threadgroup, bind args and buffers
3831+             {
3832+                 id <MTLComputePipelineState > pipeline = nil ;
3833+ 
3834+                 switch  (dst->type ) { // one pipeline per destination type
3835+                     case  GGML_TYPE_F32:    pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F32   ].pipeline ; break ;
3836+                     case  GGML_TYPE_F16:    pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_F16   ].pipeline ; break ;
3837+                     case  GGML_TYPE_BF16:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16  ].pipeline ; break ;
3838+                     case  GGML_TYPE_Q8_0:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0  ].pipeline ; break ;
3839+                     case  GGML_TYPE_Q4_0:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0  ].pipeline ; break ;
3840+                     case  GGML_TYPE_Q4_1:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1  ].pipeline ; break ;
3841+                     case  GGML_TYPE_Q5_0:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0  ].pipeline ; break ;
3842+                     case  GGML_TYPE_Q5_1:   pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1  ].pipeline ; break ;
3843+                     case  GGML_TYPE_IQ4_NL: pipeline = ctx->kernels [GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline ; break ;
3844+                     default : GGML_ABORT (" not implemented"  ); // any other destination type aborts here
3845+                 }
3846+ 
3847+                 const  int32_t  nk0 = ne0/ggml_blck_size (dst->type ); // ne0 divided by the block size of the dst type; presumably the number of blocks per dst row - TODO confirm where ne0 is defined
3848+ 
3849+                 int  nth = 32 ; // initial thread count: 32, the SIMD width
3850+ 
3851+                 while  (nth < nk0 && nth < (int ) pipeline.maxTotalThreadsPerThreadgroup ) { // double nth until it covers nk0 or reaches the pipeline's thread limit
3852+                     nth *= 2 ;
3853+                 }
3854+ 
3855+                 int  nrptg = 1 ; // y extent of the threadgroup in the dispatch below; name suggests "rows per threadgroup"
3856+                 if  (nth > nk0) { // more threads than blocks: spread the surplus over the y dimension
3857+                     nrptg = (nth + nk0 - 1 )/nk0; // ceil(nth/nk0), using nth as left by the loop above
3858+                     nth   = nk0;
3859+ 
3860+                     if  (nrptg*nth > (int ) pipeline.maxTotalThreadsPerThreadgroup ) { // rounding up may exceed the pipeline limit; back off by one
3861+                         nrptg--;
3862+                     }
3863+                 }
3864+ 
3865+                 nth = MIN (nth, nk0); // no-op after the branch above, which already leaves nth <= nk0
3866+ 
3867+                 ggml_metal_kargs_set_rows args = { // positional initializer; the field names in the comments presumably mirror the struct declared elsewhere - verify order against its definition
3868+                     /* .nk0  =*/   nk0,
3869+                     /* .ne01 =*/   ne01,
3870+                     /* .nb01 =*/   nb01,
3871+                     /* .nb02 =*/   nb02,
3872+                     /* .nb03 =*/   nb03,
3873+                     /* .ne11 =*/   ne11,
3874+                     /* .ne12 =*/   ne12,
3875+                     /* .nb10 =*/   nb10,
3876+                     /* .nb11 =*/   nb11,
3877+                     /* .nb12 =*/   nb12,
3878+                     /* .nb1  =*/   nb1,
3879+                     /* .nb2  =*/   nb2,
3880+                     /* .nb3  =*/   nb3,
3881+                 };
3882+ 
3883+                 [encoder setComputePipelineState: pipeline];
3884+                 [encoder setBytes: &args    length: sizeof (args) atIndex: 0 ]; // binding layout: args at 0, src0 at 1, src1 at 2, dst at 3
3885+                 [encoder setBuffer: id_src0 offset: offs_src0    atIndex: 1 ];
3886+                 [encoder setBuffer: id_src1 offset: offs_src1    atIndex: 2 ];
3887+                 [encoder setBuffer: id_dst  offset: offs_dst     atIndex: 3 ];
3888+ 
3889+                 [encoder dispatchThreadgroups: MTLSizeMake ((ne01 + nrptg - 1 )/nrptg, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, nrptg, 1 )]; // grid of (ceil(ne01/nrptg), ne02, ne03) threadgroups, each (nth, nrptg, 1) threads
3890+             } break ;
37873891        case  GGML_OP_RMS_NORM:
37883892            {
37893893                GGML_ASSERT (ne00 % 4  == 0 );
0 commit comments