diff --git a/src/operator/math_functions-inl.h b/src/operator/math_functions-inl.h index be5bbe20d036..45d74a62d8bd 100644 --- a/src/operator/math_functions-inl.h +++ b/src/operator/math_functions-inl.h @@ -35,19 +35,24 @@ namespace op { namespace math { // Wrappers for math.h unary and binary functions -// - For DType != double: math::name(a) does computation in float +// - For DType == float: math::name(a) does computation in float // and returns float -// - For DType == double: math::name(a) does computation in double +// - For DType == double or DType == integer: math::name(a) does computation in double // and returns double #define MXNET_UNARY_MATH_FUNC(name) \ -template MSHADOW_XINLINE \ -float name(DType a) { \ - return ::name##f(static_cast(a)); \ +MSHADOW_XINLINE \ +float name(float a) { \ + return ::name##f(a); \ } \ MSHADOW_XINLINE \ double name(double a) { \ return ::name(a); \ +} \ +template MSHADOW_XINLINE \ +typename std::enable_if::value, double>::type \ +name(DType a) { \ + return ::name(static_cast(a)); \ } #define MXNET_BINARY_MATH_FUNC(name) \ diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index a05b3eae0e44..7091abf7308a 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1952,6 +1952,25 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs): {'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3}) +def test_large_int_rounding(): + large_integer = 50000001 + + a = mx.nd.array([large_integer], dtype='int32') + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').floor() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').round() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').ceil() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').trunc() + assert np.all(a == large_integer) + + if __name__ == '__main__': import nose nose.runmodule()