From 982c031bc36ecf0fc4faaaf3bddebf08be4d91fd Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Mon, 25 Feb 2019 10:34:32 -0800 Subject: [PATCH 1/3] fix integer precision loss due to casting to float --- src/operator/math_functions-inl.h | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/operator/math_functions-inl.h b/src/operator/math_functions-inl.h index be5bbe20d036..9a234061e697 100644 --- a/src/operator/math_functions-inl.h +++ b/src/operator/math_functions-inl.h @@ -41,13 +41,18 @@ namespace math { // and returns double #define MXNET_UNARY_MATH_FUNC(name) \ -template MSHADOW_XINLINE \ -float name(DType a) { \ - return ::name##f(static_cast(a)); \ +MSHADOW_XINLINE \ +float name(float a) { \ + return ::name##f(a); \ } \ MSHADOW_XINLINE \ double name(double a) { \ return ::name(a); \ +} \ +template MSHADOW_XINLINE \ +typename std::enable_if::value, double>::type \ +name(DType a) { \ + return ::name(static_cast(a)); \ } #define MXNET_BINARY_MATH_FUNC(name) \ From 574bc64e196184a4024fbf1ff5e4370b35d8b4c7 Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Thu, 5 Sep 2019 14:29:50 -0700 Subject: [PATCH 2/3] add test --- tests/python/unittest/test_ndarray.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index a05b3eae0e44..7091abf7308a 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1952,6 +1952,25 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs): {'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3}) +def test_large_int_rounding(): + large_integer = 50000001 + + a = mx.nd.array([large_integer], dtype='int32') + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').floor() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').round() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').ceil() + assert np.all(a == large_integer) + + a = mx.nd.array([large_integer], dtype='int32').trunc() + assert np.all(a == large_integer) + + if __name__ == '__main__': import nose nose.runmodule() From e44e89d543fa99691ff3b120e0205732c9e6f86b Mon Sep 17 00:00:00 2001 From: Lin Yuan Date: Fri, 6 Sep 2019 13:15:42 -0700 Subject: [PATCH 3/3] update comment --- src/operator/math_functions-inl.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/math_functions-inl.h b/src/operator/math_functions-inl.h index 9a234061e697..45d74a62d8bd 100644 --- a/src/operator/math_functions-inl.h +++ b/src/operator/math_functions-inl.h @@ -35,9 +35,9 @@ namespace op { namespace math { // Wrappers for math.h unary and binary functions -// - For DType != double: math::name(a) does computation in float +// - For DType == float: math::name(a) does computation in float // and returns float -// - For DType == double: math::name(a) does computation in double +// - For DType == double or DType == integer: math::name(a) does computation in double // and returns double #define MXNET_UNARY_MATH_FUNC(name) \