100100 GGML_METAL_DECL_KERNEL (mul_mm_q4_K_f32);
101101 GGML_METAL_DECL_KERNEL (mul_mm_q5_K_f32);
102102 GGML_METAL_DECL_KERNEL (mul_mm_q6_K_f32);
103- GGML_METAL_DECL_KERNEL (rope);
103+ GGML_METAL_DECL_KERNEL (rope_f32);
104+ GGML_METAL_DECL_KERNEL (rope_f16);
104105 GGML_METAL_DECL_KERNEL (alibi_f32);
105106 GGML_METAL_DECL_KERNEL (cpy_f32_f16);
106107 GGML_METAL_DECL_KERNEL (cpy_f32_f32);
@@ -261,7 +262,8 @@ @implementation GGMLMetalClass
261262 GGML_METAL_ADD_KERNEL (mul_mm_q4_K_f32);
262263 GGML_METAL_ADD_KERNEL (mul_mm_q5_K_f32);
263264 GGML_METAL_ADD_KERNEL (mul_mm_q6_K_f32);
264- GGML_METAL_ADD_KERNEL (rope);
265+ GGML_METAL_ADD_KERNEL (rope_f32);
266+ GGML_METAL_ADD_KERNEL (rope_f16);
265267 GGML_METAL_ADD_KERNEL (alibi_f32);
266268 GGML_METAL_ADD_KERNEL (cpy_f32_f16);
267269 GGML_METAL_ADD_KERNEL (cpy_f32_f32);
@@ -335,7 +337,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
335337 GGML_METAL_DEL_KERNEL (mul_mm_q4_K_f32);
336338 GGML_METAL_DEL_KERNEL (mul_mm_q5_K_f32);
337339 GGML_METAL_DEL_KERNEL (mul_mm_q6_K_f32);
338- GGML_METAL_DEL_KERNEL (rope);
340+ GGML_METAL_DEL_KERNEL (rope_f32);
341+ GGML_METAL_DEL_KERNEL (rope_f16);
339342 GGML_METAL_DEL_KERNEL (alibi_f32);
340343 GGML_METAL_DEL_KERNEL (cpy_f32_f16);
341344 GGML_METAL_DEL_KERNEL (cpy_f32_f32);
@@ -870,7 +873,7 @@ void ggml_metal_graph_compute(
870873 } break ;
871874 case GGML_OP_SOFT_MAX:
872875 {
873- const int nth = 32 ;
876+ const int nth = MIN ( 32 , ne00) ;
874877
875878 if (ne00%4 == 0 ) {
876879 [encoder setComputePipelineState: ctx->pipeline_soft_max_4];
@@ -1134,7 +1137,7 @@ void ggml_metal_graph_compute(
11341137 float eps;
11351138 memcpy (&eps, dst->op_params , sizeof (float ));
11361139
1137- const int nth = 512 ;
1140+ const int nth = MIN ( 512 , ne00) ;
11381141
11391142 [encoder setComputePipelineState: ctx->pipeline_rms_norm];
11401143 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1153,7 +1156,7 @@ void ggml_metal_graph_compute(
11531156 float eps;
11541157 memcpy (&eps, dst->op_params , sizeof (float ));
11551158
1156- const int nth = 256 ;
1159+ const int nth = MIN ( 256 , ne00) ;
11571160
11581161 [encoder setComputePipelineState: ctx->pipeline_norm];
11591162 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
@@ -1171,6 +1174,8 @@ void ggml_metal_graph_compute(
11711174 {
11721175 GGML_ASSERT ((src0t == GGML_TYPE_F32));
11731176
1177+ const int nth = MIN (1024 , ne00);
1178+
11741179 const int n_past = ((int32_t *) dst->op_params )[0 ]; UNUSED (n_past);
11751180 const int n_head = ((int32_t *) dst->op_params )[1 ];
11761181 float max_bias;
@@ -1204,15 +1209,15 @@ void ggml_metal_graph_compute(
12041209 [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 17 ];
12051210 [encoder setBytes: &m0 length: sizeof ( float ) atIndex: 18 ];
12061211
1207- const int nth = 32 ;
1208-
12091212 [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth, 1 , 1 )];
12101213 } break ;
12111214 case GGML_OP_ROPE:
12121215 {
12131216 GGML_ASSERT (ne10 == ne02);
12141217
1215- // const int n_past = ((int32_t *) dst->op_params)[0];
1218+ const int nth = MIN (1024 , ne00);
1219+
1220+ const int n_past = ((int32_t *) dst->op_params )[0 ];
12161221 const int n_dims = ((int32_t *) dst->op_params )[1 ];
12171222 const int mode = ((int32_t *) dst->op_params )[2 ];
12181223
@@ -1221,7 +1226,12 @@ void ggml_metal_graph_compute(
12211226 memcpy (&freq_base, (int32_t *) dst->op_params + 4 , sizeof (float ));
12221227 memcpy (&freq_scale, (int32_t *) dst->op_params + 5 , sizeof (float ));
12231228
1224- [encoder setComputePipelineState: ctx->pipeline_rope];
1229+ switch (src0->type ) {
1230+ case GGML_TYPE_F32: [encoder setComputePipelineState: ctx->pipeline_rope_f32]; break ;
1231+ case GGML_TYPE_F16: [encoder setComputePipelineState: ctx->pipeline_rope_f16]; break ;
1232+ default : GGML_ASSERT (false );
1233+ };
1234+
12251235 [encoder setBuffer: id_src0 offset: offs_src0 atIndex: 0 ];
12261236 [encoder setBuffer: id_src1 offset: offs_src1 atIndex: 1 ];
12271237 [encoder setBuffer: id_dst offset: offs_dst atIndex: 2 ];
@@ -1241,19 +1251,19 @@ void ggml_metal_graph_compute(
12411251 [encoder setBytes: &nb1 length: sizeof (uint64_t ) atIndex: 16 ];
12421252 [encoder setBytes: &nb2 length: sizeof (uint64_t ) atIndex: 17 ];
12431253 [encoder setBytes: &nb3 length: sizeof (uint64_t ) atIndex: 18 ];
1244- // [encoder setBytes:&n_past length:sizeof( int) atIndex:19];
1254+ [encoder setBytes: &n_past length: sizeof ( int ) atIndex: 19 ];
12451255 [encoder setBytes: &n_dims length: sizeof ( int ) atIndex: 20 ];
12461256 [encoder setBytes: &mode length: sizeof ( int ) atIndex: 21 ];
12471257 [encoder setBytes: &freq_base length: sizeof (float ) atIndex: 22 ];
12481258 [encoder setBytes: &freq_scale length: sizeof (float ) atIndex: 23 ];
12491259
1250- [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (32 , 1 , 1 )];
1260+ [encoder dispatchThreadgroups: MTLSizeMake (ne01, ne02, ne03) threadsPerThreadgroup: MTLSizeMake (nth , 1 , 1 )];
12511261 } break ;
12521262 case GGML_OP_DUP:
12531263 case GGML_OP_CPY:
12541264 case GGML_OP_CONT:
12551265 {
1256- const int nth = 32 ;
1266+ const int nth = MIN ( 1024 , ne00) ;
12571267
12581268 switch (src0t) {
12591269 case GGML_TYPE_F32:
0 commit comments