Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fix unary operator ceil/floor/trunc when data type is integer (#14251)
Browse files Browse the repository at this point in the history
* fix integer precision loss due to casting to float

* add test

* update comment
  • Loading branch information
apeforest authored and eric-haibin-lin committed Sep 7, 2019
1 parent 24f0a10 commit 4f8bc3a
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
15 changes: 10 additions & 5 deletions src/operator/math_functions-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,24 @@ namespace op {
namespace math {

// Wrappers for math.h unary and binary functions
// - For DType != double: math::name(a) does computation in float
// - For DType == float: math::name(a) does computation in float
// and returns float
// - For DType == double: math::name(a) does computation in double
// - For DType == double or DType == integer: math::name(a) does computation in double
// and returns double

#define MXNET_UNARY_MATH_FUNC(name) \
template<typename DType> MSHADOW_XINLINE \
float name(DType a) { \
return ::name##f(static_cast<float>(a)); \
MSHADOW_XINLINE \
float name(float a) { \
return ::name##f(a); \
} \
MSHADOW_XINLINE \
double name(double a) { \
return ::name(a); \
} \
template<typename DType> MSHADOW_XINLINE \
typename std::enable_if<std::is_integral<DType>::value, double>::type \
name(DType a) { \
return ::name(static_cast<double>(a)); \
}

#define MXNET_BINARY_MATH_FUNC(name) \
Expand Down
19 changes: 19 additions & 0 deletions tests/python/unittest/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,25 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3})


def test_large_int_rounding():
large_integer = 50000001

a = mx.nd.array([large_integer], dtype='int32')
assert np.all(a == large_integer)

a = mx.nd.array([large_integer], dtype='int32').floor()
assert np.all(a == large_integer)

a = mx.nd.array([large_integer], dtype='int32').round()
assert np.all(a == large_integer)

a = mx.nd.array([large_integer], dtype='int32').ceil()
assert np.all(a == large_integer)

a = mx.nd.array([large_integer], dtype='int32').trunc()
assert np.all(a == large_integer)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 4f8bc3a

Please sign in to comment.