@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
461461 int ny = mm.funcs.size();
462462 while (ny > 0 && !mm.funcs[ny-1]) --ny;
463463 if (ny >= r2) {
464- int nx64 = Nx/64;
465- int nchunk64 = nx64*ne02;
466- for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
467- int i02 = ichunk/nx64;
468- int ix = 64*(ichunk - i02*nx64);
464+ nchunk = nx32*ne02;
465+ for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
466+ int i02 = ichunk/nx32;
467+ int ix = 32*(ichunk - i02*nx32);
469468 DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
470- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
471- }
472- int ix0 = 64*nx64;
473- if (ix0 < Nx) {
474- nx32 -= 2*nx64;
475- nchunk = nx32*ne02;
476- for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
477- int i02 = ichunk/nx32;
478- int ix = ix0 + 32*(ichunk - i02*nx32);
479- DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
480- mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
481- }
469+ mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
482470 }
471+ return true;
483472 }
484- return true;
485473 }
486474 for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
487475 int i02 = ichunk/nx32;
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
494482 }
495483 return true;
496484 }
497- //if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
498485 int gcd = simple_gcd(ne02, nth);
499486 int counter = 0;
500487 for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
510497 }
511498
512499 if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
513- //printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
514500 MulMat mm;
515501 if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
516502 return false;
0 commit comments