Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions ggml/src/iqk/iqk_mul_mat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,27 +461,15 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
int ny = mm.funcs.size();
while (ny > 0 && !mm.funcs[ny-1]) --ny;
if (ny >= r2) {
int nx64 = Nx/64;
int nchunk64 = nx64*ne02;
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
int i02 = ichunk/nx64;
int ix = 64*(ichunk - i02*nx64);
nchunk = nx32*ne02;
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
int ix = 32*(ichunk - i02*nx32);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
}
int ix0 = 64*nx64;
if (ix0 < Nx) {
nx32 -= 2*nx64;
nchunk = nx32*ne02;
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
int ix = ix0 + 32*(ichunk - i02*nx32);
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
}
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
}
return true;
}
return true;
}
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
int i02 = ichunk/nx32;
Expand All @@ -494,7 +482,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}
return true;
}
//if (ith == 0) printf("Using this: Nx = %d, r2 = %d, ne02 = %d\n", (int)Nx, (int)r2,(int)ne02);
int gcd = simple_gcd(ne02, nth);
int counter = 0;
for (int64_t i12 = 0; i12 < ne02; i12++) {
Expand All @@ -510,7 +497,6 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
}

if (ne13 == 1 && ne12 > 1 && ne12 == ne02 && Ny == 1 && nb02 < strideA) {
//printf("TG attention gemm for %d heads and Nx = %d\n", (int)ne02, (int)Nx);
MulMat mm;
if (!MulMat::prepare(typeA, typeB, ne00, mm, Ny)) {
return false;
Expand Down