Skip to content

Commit

Permalink
Update zarch SCAL kernels to handle INF and NAN arguments (#4829)
Browse files Browse the repository at this point in the history
* handle INF and NAN in input (for S/D only if DUMMY2 argument is set)
  • Loading branch information
martin-frbg committed Jul 31, 2024
1 parent 136a4ed commit edbf093
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 78 deletions.
60 changes: 46 additions & 14 deletions kernel/zarch/cscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -234,22 +234,38 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
} else {

while (j < n1) {

temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
if (isnan(x[i]) || isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1];
if (!isinf(x[i + 1]))
x[i + 1] = da_i * x[i];
else
x[i + 1] = NAN;
x[i] = temp0;
temp1 = -da_i * x[i + 1 + inc_x];
x[i + 1 + inc_x] = da_i * x[i + inc_x];
if (isnan(x[i+inc_x]) || isinf(x[i+inc_x]))
temp1 = NAN;
else
temp1 = -da_i * x[i + 1 + inc_x];
if (!isinf(x[i + 1 + inc_x]))
x[i + 1 + inc_x] = da_i * x[i + inc_x];
else
x[i + 1 + inc_x] = NAN;
x[i + inc_x] = temp1;
i += 2 * inc_x;
j += 2;

}

while (j < n) {

temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
if (isnan(x[i]) || isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1];
if (isinf(x[i + 1]))
x[i + 1] = NAN;
else
x[i + 1] = da_i * x[i];
x[i] = temp0;
i += inc_x;
j++;
Expand Down Expand Up @@ -332,26 +348,42 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
j = n1;
}

if (da_r == 0.0) {
if (da_r == 0.0 || isnan(da_r)) {

if (da_i == 0.0) {

float res = 0.0;
if (isnan(da_r)) res = da_r;
while (j < n) {

x[i] = 0.0;
x[i + 1] = 0.0;
x[i] = res;
x[i + 1] = res;
i += 2;
j++;

}
} else if (isinf(da_r)) {
while(j < n)
{

x[i]= NAN;
x[i+1] = da_r;
i += 2 ;
j++;

}

} else {

while (j < n) {

temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
x[i] = temp0;
if (isinf(x[i])) temp0 = NAN;
if (!isinf(x[i + 1]))
x[i + 1] = da_i * x[i];
else
x[i + 1] = NAN;
if (x[i] == x[i])
x[i] = temp0;
i += 2;
j++;

Expand Down
46 changes: 27 additions & 19 deletions kernel/zarch/dscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -96,20 +96,28 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
if (inc_x == 1) {

if (da == 0.0) {

BLASLONG n1 = n & -16;
if (n1 > 0) {

dscal_kernel_16_zero(n1, x);
j = n1;

if (dummy2 == 0) {
BLASLONG n1 = n & -16;
if (n1 > 0) {
dscal_kernel_16_zero(n1, x);
j = n1;
}

while (j < n) {
x[j] = 0.0;
j++;
}
} else {
while (j < n) {
if (isfinite(x[j]))
x[j] = 0.0;
else
x[j] = NAN;
j++;
}
}

while (j < n) {

x[j] = 0.0;
j++;
}


} else {

BLASLONG n1 = n & -16;
Expand All @@ -127,23 +135,23 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,
} else {

if (da == 0.0) {

if (dummy2 == 0) {
BLASLONG n1 = n & -4;

while (j < n1) {

x[i] = 0.0;
x[i + inc_x] = 0.0;
x[i + 2 * inc_x] = 0.0;
x[i + 3 * inc_x] = 0.0;

i += inc_x * 4;
j += 4;

}
}
while (j < n) {

x[i] = 0.0;
if (dummy2==0 || isfinite(x[i]))
x[i] = 0.0;
else
x[i] = NAN;
i += inc_x;
j++;
}
Expand Down
89 changes: 55 additions & 34 deletions kernel/zarch/sscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,21 +95,31 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,

if (inc_x == 1) {

if (da == 0.0) {

BLASLONG n1 = n & -32;
if (n1 > 0) {

sscal_kernel_32_zero(n1, x);
j = n1;
}

while (j < n) {

x[j] = 0.0;
j++;
if (da == 0.0 || !isfinite(da)) {
if (dummy2 == 0) {
BLASLONG n1 = n & -32;
if (n1 > 0) {

sscal_kernel_32_zero(n1, x);
j = n1;
}

while (j < n) {

x[j] = 0.0;
j++;
}
} else {
float res = 0.0;
if (!isfinite(da)) res = NAN;
while (j < n) {
if (isfinite(x[i]))
x[j] = res;
else
x[j] = NAN;
j++;
}
}

} else {

BLASLONG n1 = n & -32;
Expand All @@ -126,26 +136,37 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da, FLOAT *x,

} else {

if (da == 0.0) {

BLASLONG n1 = n & -2;

while (j < n1) {

x[i] = 0.0;
x[i + inc_x] = 0.0;

i += inc_x * 2;
j += 2;

}
while (j < n) {

x[i] = 0.0;
i += inc_x;
j++;
}

if (da == 0.0 || !isfinite(da)) {
if (dummy2 == 0) {
BLASLONG n1 = n & -2;

while (j < n1) {

x[i] = 0.0;
x[i + inc_x] = 0.0;

i += inc_x * 2;
j += 2;

}
while (j < n) {

x[i] = 0.0;
i += inc_x;
j++;
}
} else {
while (j < n) {
float res = 0.0;
if (!isfinite(da)) res = NAN;
if (isfinite(x[i]))
x[i] = res;
else
x[i] = NAN;
i += inc_x;
j++;
}
}
} else {
BLASLONG n1 = n & -2;

Expand Down
43 changes: 32 additions & 11 deletions kernel/zarch/zscal.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,13 +237,19 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
temp0 = NAN;
else
temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
if (!isinf(x[i + 1]))
x[i + 1] = da_i * x[i];
else
x[i + 1] = NAN;
x[i] = temp0;
if (isnan(x[i + inc_x]) || isinf(x[i + inc_x]))
temp1 = NAN;
else
temp1 = -da_i * x[i + 1 + inc_x];
x[i + 1 + inc_x] = da_i * x[i + inc_x];
if (!isinf(x[i + 1 + inc_x]))
x[i + 1 + inc_x] = da_i * x[i + inc_x];
else
x[i + 1 + inc_x] = NAN;
x[i + inc_x] = temp1;
i += 2 * inc_x;
j += 2;
Expand All @@ -256,7 +262,10 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
temp0 = NAN;
else
temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
if (!isinf(x[i +1]))
x[i + 1] = da_i * x[i];
else
x[i + 1] = NAN;
x[i] = temp0;
i += inc_x;
j++;
Expand Down Expand Up @@ -330,7 +339,7 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
zscal_kernel_8_zero(n1, x);
else
zscal_kernel_8(n1, da_r, da_i, x);
else if (da_i == 0)
else if (da_i == 0 && da_r == da_r)
zscal_kernel_8_zero_i(n1, alpha, x);
else
zscal_kernel_8(n1, da_r, da_i, x);
Expand All @@ -339,29 +348,41 @@ int CNAME(BLASLONG n, BLASLONG dummy0, BLASLONG dummy1, FLOAT da_r, FLOAT da_i,
j = n1;
}

if (da_r == 0.0) {
if (da_r == 0.0 || isnan(da_r)) {

if (da_i == 0.0) {

double res= 0.0;
if (isnan(da_r)) res = da_r;
while (j < n) {

x[i] = 0.0;
x[i + 1] = 0.0;
x[i] = res;
x[i + 1] = res;
i += 2;
j++;

}

} else if (isinf(da_r)) {
while (j < n) {
x[i] = NAN;
x[i + 1] = da_r;
i += 2;
j++;
}
} else {

while (j < n) {

if (isnan(x[i]) || isinf(x[i]))
if (isinf(x[i]))
temp0 = NAN;
else
temp0 = -da_i * x[i + 1];
x[i + 1] = da_i * x[i];
x[i] = temp0;
if (!isinf(x[i + 1]))
x[i + 1] = da_i * x[i];
else
x[i + 1] = NAN;
if (x[i]==x[i])
x[i] = temp0;
i += 2;
j++;

Expand Down

0 comments on commit edbf093

Please sign in to comment.