This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix NaN value comparisons in relu, max and min ops #14262

Merged: 2 commits, Mar 10, 2019
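Why the fix is needed: IEEE 754 defines every ordered comparison involving NaN as false, so the old kernels of the form a > b ? a : b silently returned the other operand (or DType(0) in relu's case) instead of propagating the NaN, whereas NumPy's maximum/minimum propagate it. A minimal Python sketch of the target semantics; the nan_max helper is hypothetical, for illustration only:

import numpy as np

def nan_max(a, b):
    # Mirrors the patched maximum::Map(): propagate a NaN found in the
    # first operand, otherwise fall back to the plain comparison.
    return a if np.isnan(a) else (a if a > b else b)

print(np.maximum(np.nan, 1.0))  # nan -- NumPy propagates NaN
print(nan_max(np.nan, 1.0))     # nan -- matches NumPy
print(nan_max(2.0, 1.0))        # 2.0 -- the ordinary case is unchanged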
48 changes: 38 additions & 10 deletions src/operator/mshadow_op.h
@@ -127,10 +127,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));
 
 MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));
 
-MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));
-
-MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));
-
 MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
                        (a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));
@@ -317,12 +313,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));
 
 MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));
 
-/*! \brief used for generate element of maximum */
-MXNET_BINARY_MATH_OP(maximum, a > b ? a : b);
-
-/*! \brief used for generate element of minimum */
-MXNET_BINARY_MATH_OP_NC(minimum, a < b ? a : b);
-
 MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));
 
 MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
@@ -788,6 +778,44 @@ namespace isnan_typed {
   }
 }; // namespace isnan_typed
 
+MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));
+
+/*! \brief used for computing gradient of relu operator */
+struct relu_grad : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a) {
+    if (isnan_typed::IsNan(a)) {
+      return a;
+    } else {
+      return a > DType(0) ? DType(1) : DType(0);
+    }
+  }
+};
+
+/*! \brief used for computing binary operator maximum */
+struct maximum : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    if (isnan_typed::IsNan(a)) {
+      return a;
+    } else {
+      return (a > b ? a : b);
+    }
+  }
+};
+
+/*! \brief used for computing binary operator minimum */
+struct minimum : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static DType Map(DType a, DType b) {
+    if (isnan_typed::IsNan(a)) {
+      return a;
+    } else {
+      return DType(a < b ? a : b);
+    }
+  }
+};
+
 /*! \brief sum reducer that ignores NaN values in the input */
 struct nansum {
   /*! \brief do reduction into dst */
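The four kernels added above share one rule: if the first operand is NaN, Map() returns it unchanged, so NaN propagates through the forward pass instead of being compared away. Assuming an MXNet build that includes this patch, the expected NDArray behavior is sketched below (illustrative values, not captured output):

import mxnet as mx
import numpy as np

a = mx.nd.array([np.nan, -1.0, 2.0])
b = mx.nd.zeros(3)

print(mx.nd.maximum(a, b).asnumpy())  # [nan  0.  2.]
print(mx.nd.minimum(a, b).asnumpy())  # [nan -1.  0.]
print(mx.nd.relu(a).asnumpy())        # [nan  0.  2.]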
29 changes: 29 additions & 0 deletions tests/python/unittest/test_ndarray.py
@@ -1550,6 +1550,35 @@ def test_ndarray_is_nan():
     np.testing.assert_equal(output.asnumpy(), expected_output.astype(int))
     # astype since numpy functions default return type is boolean array instead of int
 
+@with_seed()
+def test_ndarray_nan_comparison():
+    random_dimensions = np.random.randint(2, 5)
+    random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)]
+    data1 = mxnet.test_utils.rand_ndarray(random_shape,'default')
+    data2 = mxnet.test_utils.rand_ndarray(random_shape,'default')
+    data1[1][0] = np.NaN
+    data2[0][0] = np.NaN
+
+    nd_max = mx.nd.maximum(data1, data2)
+    np_max = np.maximum(data1.asnumpy(), data2.asnumpy())
+    np.testing.assert_equal(nd_max.asnumpy(), np_max)
+
+    nd_min = mx.nd.minimum(data1, data2)
+    np_min = np.minimum(data1.asnumpy(), data2.asnumpy())
+    np.testing.assert_equal(nd_min.asnumpy(), np_min)
+
+    nd_relu = mx.nd.relu(data1)
+    np_relu = np.maximum(data1.asnumpy(), 0)
+    np.testing.assert_equal(nd_relu.asnumpy(), np_relu)
+
+    data1.attach_grad()
+    with mx.autograd.record():
+        y = mx.nd.relu(data1)
+    y.backward()
+    data1_grad = data1.grad.asnumpy()
+    for i in (np.isnan(data1_grad))[1][0].flatten():
+        assert i == True
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
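The final assertions rely on the gradient side of the fix: relu_grad now returns the NaN itself rather than 0 or 1, so backward() leaves NaN in the gradient exactly where the input was NaN. A condensed sketch of that behavior, again assuming a build with this patch applied:

import mxnet as mx
import numpy as np

x = mx.nd.array([np.nan, -1.0, 2.0])
x.attach_grad()
with mx.autograd.record():
    y = mx.nd.relu(x)
y.backward()
print(x.grad.asnumpy())  # expected: [nan  0.  1.]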