Skip to content

Commit 3f8c865

Browse files
ikawrakow and Iwan Kawrakow authored
Fix standard attention on the CPU (#421)
Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 14ed9fb commit 3f8c865

File tree

1 file changed

+6
-20
lines changed

1 file changed

+6
-20
lines changed

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 6 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
461461
int ny = mm.funcs.size();
462462
while (ny > 0 && !mm.funcs[ny-1]) --ny;
463463
if (ny >= r2) {
464-
int nx64 = Nx/64;
465-
int nchunk64 = nx64*ne02;
466-
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
467-
int i02 = ichunk/nx64;
468-
int ix = 64*(ichunk - i02*nx64);
464+
nchunk = nx32*ne02;
465+
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
466+
int i02 = ichunk/nx32;
467+
int ix = 32*(ichunk - i02*nx32);
469468
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
470-
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
471-
}
472-
int ix0 = 64*nx64;
473-
if (ix0 < Nx) {
474-
nx32 -= 2*nx64;
475-
nchunk = nx32*ne02;
476-
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
477-
int i02 = ichunk/nx32;
478-
int ix = ix0 + 32*(ichunk - i02*nx32);
479-
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
480-
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
481-
}
469+
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
482470
}
471+
return true;
483472
}
484-
return true;
485473
}
486474
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
487475
int i02 = ichunk/nx32;
@@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
494482
}
495483
return true;
496484
}
497-
//if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
498485
int gcd = simple_gcd(ne02, nth);
499486
int counter = 0;
500487
for (int64_t i12 = 0; i12 < ne02; i12++) {
@@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
510497
}
511498

512499
if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
513-
//printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
514500
MulMat mm;
515501
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
516502
return false;

0 commit comments

Comments
 (0)