Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
nan comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Feb 27, 2019
1 parent 0af40f7 commit 1d05575
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 10 deletions.
38 changes: 28 additions & 10 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,10 +127,6 @@ MXNET_UNARY_MATH_OP(softsign, a / (1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP(softsign_grad, 1.0f / math::sqr(1.0f + math::fabs(a)));

MXNET_UNARY_MATH_OP_NC(relu, a > DType(0) ? a : DType(0));

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));

MXNET_UNARY_MATH_OP_NC(selu, DType(SELU_LAMBDA) *
(a > DType(0) ? a : DType(math::id(SELU_ALPHA) * math::expm1(a))));

Expand Down Expand Up @@ -317,12 +313,6 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a));

MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b));

/*! \brief used for generate element of maximum */
MXNET_BINARY_MATH_OP(maximum, a > b ? a : b);

/*! \brief used for generate element of minimum */
MXNET_BINARY_MATH_OP_NC(minimum, a < b ? a : b);

MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1));

MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0));
Expand Down Expand Up @@ -788,6 +778,34 @@ namespace isnan_typed {
}
}; // namespace isnan_typed

MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));

MXNET_UNARY_MATH_OP_NC(relu_grad, a > DType(0) ? DType(1) : DType(0));

/*! \brief used for binary operator maximum */
struct maximum : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (isnan_typed::IsNan(a)) {
return a;
} else {
return (a > b ? a : b);
}
}
};

/*! \brief used for binary operator minimum */
struct minimum : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static DType Map(DType a, DType b) {
if (isnan_typed::IsNan(a)) {
return a;
} else {
return DType(a < b ? a : b);
}
}
};

/*! \brief sum reducer that ignores NaN values in the input */
struct nansum {
/*! \brief do reduction into dst */
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,6 +1550,27 @@ def test_ndarray_is_nan():
np.testing.assert_equal(output.asnumpy(), expected_output.astype(int))
# astype since numpy functions default return type is boolean array instead of int

@with_seed()
def test_ndarray_nan_comparison():
random_dimensions = np.random.randint(2, 5)
random_shape = [np.random.randint(2, 5) for i in range(random_dimensions)]
data1 = mxnet.test_utils.rand_ndarray(random_shape,'default')
data2 = mxnet.test_utils.rand_ndarray(random_shape,'default')
data1[1][0] = np.NaN
data2[0][0] = np.NaN

nd_max = mx.nd.maximum(data1, data2)
np_max = np.maximum(data1.asnumpy(), data2.asnumpy())
np.testing.assert_equal(nd_max.asnumpy(), np_max)

nd_min = mx.nd.minimum(data1, data2)
np_min = np.minimum(data1.asnumpy(), data2.asnumpy())
np.testing.assert_equal(nd_min.asnumpy(), np_min)

nd_relu = mx.nd.relu(data1)
np_relu = np.maximum(data1.asnumpy(), 0)
np.testing.assert_equal(nd_relu.asnumpy(), np_relu)

if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 1d05575

Please sign in to comment.