@@ -322,6 +322,11 @@ struct MulMat {
322322 }
323323};
324324
/// Hands out the calling thread's reusable scratch byte buffer.
///
/// The buffer is a function-local `thread_local` vector, so each thread gets
/// its own instance and the same instance is returned on every call made from
/// that thread. Nothing is resized or cleared here: the buffer starts out
/// empty and callers grow it to the size they need before writing through
/// `data()`.
///
/// @return Reference to the per-thread `std::vector<char>` scratch storage.
static std::vector<char>& thread_local_work_buffer() {
    thread_local std::vector<char> scratch;
    return scratch;
}
329+
325330}
326331
327332extern " C" IQK_API bool iqk_mul_mat (long Nx, long Ny, long ne00,
@@ -349,15 +354,15 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
349354
350355 auto type_size = ggml_type_size (dequant_type);
351356
352- thread_local std::vector<char > f;
353-
354357 size_t row_size_qx = ne00*type_size;
355358 size_t row_size_qy = strideB;
356359
357360 // printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
358361
359362 DataInfo info{C + first_x, (const char *)B, (size_t )stride_C, row_size_qy, 0 , 1 , nullptr , 0 };
360363
364+ auto & f = thread_local_work_buffer ();
365+
361366 for (int ix = 0 ; ix < nrc_x; ix += k_x_step) {
362367 auto this_info = info;
363368 this_info.s += ix;
@@ -501,6 +506,47 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
501506 assert (row_mapping != nullptr );
502507
503508 MulMat mm;
509+
510+ auto etypeA = ggml_type (typeA);
511+ if (auto dequant_type = MulMat::is_dequant_better (etypeA, Ny); dequant_type != etypeA) {
512+ if (!MulMat::prepare (dequant_type, typeB, ne00, mm, Ny)) {
513+ return false ;
514+ }
515+
516+ constexpr int k_x_step = 32 ;
517+
518+ auto num_rows = MulMat::num_rows (ggml_type (dequant_type));
519+ GGML_ASSERT (Nx%num_rows == 0 );
520+ auto nrc_x = (Nx/num_rows + nth - 1 )/nth;
521+ auto first_x = ith*nrc_x;
522+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
523+ first_x *= num_rows;
524+ nrc_x *= num_rows;
525+
526+ auto type_size = ggml_type_size (dequant_type);
527+
528+ size_t row_size_qx = ne00*type_size;
529+ size_t row_size_qy = strideB;
530+
531+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof (float ), row_size_qy, 0 , ne11, row_mapping, nb2/sizeof (float )};
532+
533+ auto & f = thread_local_work_buffer ();
534+
535+ for (int ix = 0 ; ix < nrc_x; ix += k_x_step) {
536+ auto this_info = info;
537+ this_info.s += ix;
538+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
539+ if (f.size () < row_size_qx*this_nrc_x) f.resize (row_size_qx*this_nrc_x);
540+ if (!iqk_dequantize_ktquants (typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data (), ne00, this_nrc_x)) {
541+ GGML_ABORT (" Fatal error" );
542+ }
543+ mm.mul_mat_NxM (ne00, f.data (), row_size_qx, this_info, this_nrc_x, Ny);
544+ }
545+
546+ return true ;
547+
548+ }
549+
504550 if (!MulMat::prepare (typeA, typeB, ne00, mm, Ny)) {
505551 return false ;
506552 }
@@ -528,6 +574,52 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
528574 assert (row_mapping != nullptr );
529575
530576 MulMat mm;
577+
578+ auto etypeA = ggml_type (typeA);
579+ if (auto dequant_type = MulMat::is_dequant_better (etypeA, Ny); dequant_type != etypeA) {
580+ if (!MulMat::prepare (dequant_type, typeB, ne00, mm, Ny)) {
581+ return false ;
582+ }
583+
584+ constexpr int k_x_step = 64 ;
585+
586+ auto num_rows = MulMat::num_rows (ggml_type (dequant_type));
587+ GGML_ASSERT (Nx%num_rows == 0 );
588+ auto nrc_x = (Nx/num_rows + nth - 1 )/nth;
589+ auto first_x = ith*nrc_x;
590+ if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
591+ first_x *= num_rows;
592+ nrc_x *= num_rows;
593+
594+ auto type_size = ggml_type_size (dequant_type);
595+
596+ size_t row_size_qx = ne00*type_size;
597+ size_t row_size_qy = strideB;
598+
599+ DataInfo info{C + first_x, (const char *)B, nb1/sizeof (float ), row_size_qy, 0 , ne11, row_mapping, nb2/sizeof (float )};
600+
601+ auto & f = thread_local_work_buffer ();
602+
603+ for (int ix = 0 ; ix < nrc_x; ix += k_x_step) {
604+ auto this_info = info;
605+ this_info.s += ix;
606+ int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
607+ if (f.size () < 2 *row_size_qx*this_nrc_x) f.resize (2 *row_size_qx*this_nrc_x);
608+ auto Xu = f.data ();
609+ auto Xg = f.data () + row_size_qx*this_nrc_x;
610+ if (!iqk_dequantize_ktquants (typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
611+ GGML_ABORT (" Fatal error" );
612+ }
613+ if (!iqk_dequantize_ktquants (typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
614+ GGML_ABORT (" Fatal error" );
615+ }
616+ mm.mul_mat_up_gate_NxM (ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
617+ }
618+
619+ return true ;
620+
621+ }
622+
531623 if (!MulMat::prepare (typeA, typeB, ne00, mm, Ny)) {
532624 return false ;
533625 }
0 commit comments