Skip to content

Commit 1374062

Browse files
ikawrakowIwan Kawrakow
andauthored
Fix SER (CPU) (#415)
* Fixing SER bugs * Cleanup --------- Co-authored-by: Iwan Kawrakow <[email protected]>
1 parent 0c57f84 commit 1374062

File tree

2 files changed

+51
-24
lines changed

2 files changed

+51
-24
lines changed

ggml/src/ggml.c

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12472,6 +12472,11 @@ static void ggml_compute_forward_sum_rows_f32(
1247212472
float * dst_row = (float *) ((char *) dst->data + i1*nb1 + i2*nb2 + i3*nb3);
1247312473
float row_sum = 0;
1247412474
ggml_vec_sum_f32(ne00, &row_sum, src_row);
12475+
if (!isfinite(row_sum)) {
12476+
fprintf(stderr, "Oops(%s, %s): found %g for i1 = %d, i2 = %d, i3 = %d. ne00 = %d\n", __func__, dst->name,
12477+
(double)row_sum, (int)i1, (int)i2, (int)i3, (int)ne00);
12478+
exit(1);
12479+
}
1247512480
dst_row[0] = row_sum;
1247612481
}
1247712482
}
@@ -14759,6 +14764,18 @@ static void ggml_compute_forward_mul_mat_id(
1475914764

1476014765
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
1476114766

14767+
GGML_ASSERT(ids->ne[1] == dst->ne[2]);
14768+
for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
14769+
for (int id = 0; id < n_ids; ++id) {
14770+
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
14771+
if (i02 < 0 || i02 >= n_as) {
14772+
// This is needed for SER. If fewer experts have been activated for this row, we need to
14773+
// clear it, else there could be garbage that leads to NaNs later on.
14774+
memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
14775+
}
14776+
}
14777+
}
14778+
1476214779
if (ith == 0) {
1476314780
// initialize matrix_row_counts
1476414781
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -15012,6 +15029,18 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
1501215029

1501315030
#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
1501415031

15032+
GGML_ASSERT(ids->ne[1] == dst->ne[2]);
15033+
for (int64_t iid1 = ith; iid1 < ids->ne[1]; iid1 += nth) {
15034+
for (int id = 0; id < n_ids; ++id) {
15035+
const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
15036+
if (i02 < 0 || i02 >= n_as) {
15037+
// This is needed for SER. If fewer experts have been activated for this row, we need to
15038+
// clear it, else there could be garbage that leads to NaNs later on.
15039+
memset((char *)dst->data + id*dst->nb[1] + iid1*dst->nb[2], 0, dst->ne[0]*sizeof(float));
15040+
}
15041+
}
15042+
}
15043+
1501515044
if (ith == 0) {
1501615045
// initialize matrix_row_counts
1501715046
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
@@ -15916,7 +15945,7 @@ static void ggml_compute_forward_get_rows_f16(
1591615945
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
1591715946
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
1591815947
} else {
15919-
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
15948+
memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
1592015949
}
1592115950

1592215951
}
@@ -15960,7 +15989,7 @@ static void ggml_compute_forward_get_rows_bf16(
1596015989
(const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
1596115990
(float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
1596215991
} else {
15963-
memset((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03, 0, nc*sizeof(float));
15992+
memset((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3, 0, nc*sizeof(float));
1596415993
}
1596515994
}
1596615995
}

ggml/src/iqk/iqk_mul_mat.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -458,31 +458,29 @@ extern "C" IQK_API bool iqk_mul_mat_4d(long Nx, long Ny, long ne00,
458458
if (r2 <= 8) {
459459
MulMat mm;
460460
if (!MulMat::prepare(typeA, typeB, ne00, mm, r2)) return false;
461-
int nx64 = Nx/64;
462-
int nchunk64 = nx64*ne02;
463-
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
464-
int i02 = ichunk/nx64;
465-
int ix = 64*(ichunk - i02*nx64);
466-
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
467-
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
468-
}
469-
int ix0 = 64*nx64;
470-
if (ix0 < Nx) {
471-
nx32 -= 2*nx64;
472-
nchunk = nx32*ne02;
473-
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
474-
int i02 = ichunk/nx32;
475-
int ix = ix0 + 32*(ichunk - i02*nx32);
461+
int ny = mm.funcs.size();
462+
while (ny > 0 && !mm.funcs[ny-1]) --ny;
463+
if (ny >= r2) {
464+
int nx64 = Nx/64;
465+
int nchunk64 = nx64*ne02;
466+
for (int ichunk = ith; ichunk < nchunk64; ichunk += nth) {
467+
int i02 = ichunk/nx64;
468+
int ix = 64*(ichunk - i02*nx64);
476469
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
477-
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
470+
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 64);
471+
}
472+
int ix0 = 64*nx64;
473+
if (ix0 < Nx) {
474+
nx32 -= 2*nx64;
475+
nchunk = nx32*ne02;
476+
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
477+
int i02 = ichunk/nx32;
478+
int ix = ix0 + 32*(ichunk - i02*nx32);
479+
DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
480+
mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
481+
}
478482
}
479483
}
480-
//for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {
481-
// int i02 = ichunk/nx32;
482-
// int ix = 32*(ichunk - i02*nx32);
483-
// DataInfo info{C + ix + r2*i02*nb2, (const char *)B + r2*i02*nb12, (size_t)nb2, (size_t)nb12, 0, 1, nullptr, 0};
484-
// mm.funcs[r2-1](ne00, (const void *)((const char *)A + ix*strideA + i02*nb02), strideA, info, 32);
485-
//}
486484
return true;
487485
}
488486
for (int ichunk = ith; ichunk < nchunk; ichunk += nth) {

0 commit comments

Comments
 (0)