Skip to content

Commit 0b10f74

Browse files
ikawrakow and Iwan Kawrakow authored
Faster CPU prompt processing for Trellis quants and MoE models (#488)
* Also do the dequantize approach for mul_mat_id
* Also do the dequantize approach for iqk_moe_fused_up_gate

---------

Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 7e79665 commit 0b10f74

File tree

1 file changed

+94
-2
lines changed

1 file changed

+94
-2
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 94 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -322,6 +322,11 @@ struct MulMat {
322322
}
323323
};
324324

325+
// Returns a reference to a per-thread byte buffer shared by the callers in this file.
// The vector is a function-local `thread_local`, so each thread gets its own instance,
// and that instance (including whatever size it has been grown to) persists across calls.
// This function never resizes the buffer itself; callers are expected to grow it to the
// size they need before using `data()`.
static std::vector<char> & thread_local_work_buffer() {
326+
thread_local std::vector<char> f;
327+
return f;
328+
}
329+
325330
}
326331

327332
extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
@@ -349,15 +354,15 @@ extern "C" IQK_API bool iqk_mul_mat(long Nx, long Ny, long ne00,
349354

350355
auto type_size = ggml_type_size(dequant_type);
351356

352-
thread_local std::vector<char> f;
353-
354357
size_t row_size_qx = ne00*type_size;
355358
size_t row_size_qy = strideB;
356359

357360
//printf("Dequant mul mat %s x %s: ne00 = %d, row_size = %d\n", ggml_type_name(dequant_type), ggml_type_name(ggml_type(typeB)), (int)ne00, (int)row_size_qx);
358361

359362
DataInfo info{C + first_x, (const char *)B, (size_t)stride_C, row_size_qy, 0, 1, nullptr, 0};
360363

364+
auto& f = thread_local_work_buffer();
365+
361366
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
362367
auto this_info = info;
363368
this_info.s += ix;
@@ -501,6 +506,47 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
501506
assert(row_mapping != nullptr);
502507

503508
MulMat mm;
509+
510+
auto etypeA = ggml_type(typeA);
511+
if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
512+
if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
513+
return false;
514+
}
515+
516+
constexpr int k_x_step = 32;
517+
518+
auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
519+
GGML_ASSERT(Nx%num_rows == 0);
520+
auto nrc_x = (Nx/num_rows + nth - 1)/nth;
521+
auto first_x = ith*nrc_x;
522+
if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
523+
first_x *= num_rows;
524+
nrc_x *= num_rows;
525+
526+
auto type_size = ggml_type_size(dequant_type);
527+
528+
size_t row_size_qx = ne00*type_size;
529+
size_t row_size_qy = strideB;
530+
531+
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
532+
533+
auto& f = thread_local_work_buffer();
534+
535+
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
536+
auto this_info = info;
537+
this_info.s += ix;
538+
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
539+
if (f.size() < row_size_qx*this_nrc_x) f.resize(row_size_qx*this_nrc_x);
540+
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)A + (first_x + ix)*strideA, strideA, f.data(), ne00, this_nrc_x)) {
541+
GGML_ABORT("Fatal error");
542+
}
543+
mm.mul_mat_NxM(ne00, f.data(), row_size_qx, this_info, this_nrc_x, Ny);
544+
}
545+
546+
return true;
547+
548+
}
549+
504550
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
505551
return false;
506552
}
@@ -528,6 +574,52 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
528574
assert(row_mapping != nullptr);
529575

530576
MulMat mm;
577+
578+
auto etypeA = ggml_type(typeA);
579+
if (auto dequant_type = MulMat::is_dequant_better(etypeA, Ny); dequant_type != etypeA) {
580+
if (!MulMat::prepare(dequant_type, typeB, ne00, mm, Ny)) {
581+
return false;
582+
}
583+
584+
constexpr int k_x_step = 64;
585+
586+
auto num_rows = MulMat::num_rows(ggml_type(dequant_type));
587+
GGML_ASSERT(Nx%num_rows == 0);
588+
auto nrc_x = (Nx/num_rows + nth - 1)/nth;
589+
auto first_x = ith*nrc_x;
590+
if (first_x + nrc_x > Nx/num_rows) nrc_x = Nx/num_rows - first_x;
591+
first_x *= num_rows;
592+
nrc_x *= num_rows;
593+
594+
auto type_size = ggml_type_size(dequant_type);
595+
596+
size_t row_size_qx = ne00*type_size;
597+
size_t row_size_qy = strideB;
598+
599+
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float), row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
600+
601+
auto& f = thread_local_work_buffer();
602+
603+
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
604+
auto this_info = info;
605+
this_info.s += ix;
606+
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
607+
if (f.size() < 2*row_size_qx*this_nrc_x) f.resize(2*row_size_qx*this_nrc_x);
608+
auto Xu = f.data();
609+
auto Xg = f.data() + row_size_qx*this_nrc_x;
610+
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Aup + (first_x + ix)*strideA, strideA, Xu, ne00, this_nrc_x)) {
611+
GGML_ABORT("Fatal error");
612+
}
613+
if (!iqk_dequantize_ktquants(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
614+
GGML_ABORT("Fatal error");
615+
}
616+
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
617+
}
618+
619+
return true;
620+
621+
}
622+
531623
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
532624
return false;
533625
}

0 commit comments

Comments
 (0)