diff --git a/test/compare_sgemm_sbgemm.c b/test/compare_sgemm_sbgemm.c index cd508a0cf4..b8aaee8be3 100644 --- a/test/compare_sgemm_sbgemm.c +++ b/test/compare_sgemm_sbgemm.c @@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16) #define SBGEMM_LARGEST 256 +void *malloc_safe(size_t size) +{ + if (size == 0) + return malloc(1); + else + return malloc(size); +} + int main (int argc, char *argv[]) { @@ -100,13 +108,13 @@ main (int argc, char *argv[]) { if ((x > 100) && (x != SBGEMM_LARGEST)) continue; m = k = n = x; - float *A = (float *)malloc(m * k * sizeof(FLOAT)); - float *B = (float *)malloc(k * n * sizeof(FLOAT)); - float *C = (float *)malloc(m * n * sizeof(FLOAT)); - bfloat16_bits *AA = (bfloat16_bits *)malloc(m * k * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc(k * n * sizeof(bfloat16_bits)); - float *DD = (float *)malloc(m * n * sizeof(FLOAT)); - float *CC = (float *)malloc(m * n * sizeof(FLOAT)); + float *A = (float *)malloc_safe(m * k * sizeof(FLOAT)); + float *B = (float *)malloc_safe(k * n * sizeof(FLOAT)); + float *C = (float *)malloc_safe(m * n * sizeof(FLOAT)); + bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(m * k * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(k * n * sizeof(bfloat16_bits)); + float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1; @@ -194,16 +202,16 @@ main (int argc, char *argv[]) return ret; } - k = 1; for (x = 1; x <= loop; x++) { - float *A = (float *)malloc(x * x * sizeof(FLOAT)); - float *B = (float *)malloc(x * sizeof(FLOAT)); - float *C = (float *)malloc(x * sizeof(FLOAT)); - bfloat16_bits *AA = (bfloat16_bits *)malloc(x * x * sizeof(bfloat16_bits)); - bfloat16_bits *BB = (bfloat16_bits *)malloc(x * sizeof(bfloat16_bits)); - float *DD = (float *)malloc(x * sizeof(FLOAT)); - float *CC = (float *)malloc(x * sizeof(FLOAT)); + k = (x == 0) ? 0 : 1; + float *A = (float *)malloc_safe(x * x * sizeof(FLOAT)); + float *B = (float *)malloc_safe(x * sizeof(FLOAT)); + float *C = (float *)malloc_safe(x * sizeof(FLOAT)); + bfloat16_bits *AA = (bfloat16_bits *)malloc_safe(x * x * sizeof(bfloat16_bits)); + bfloat16_bits *BB = (bfloat16_bits *)malloc_safe(x * sizeof(bfloat16_bits)); + float *DD = (float *)malloc_safe(x * sizeof(FLOAT)); + float *CC = (float *)malloc_safe(x * sizeof(FLOAT)); if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) || (DD == NULL) || (CC == NULL)) return 1;