Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 4f8bc3a

Browse files
apeforesteric-haibin-lin
authored andcommitted
Fix unary operator ceil/floor/trunc when data type is integer (#14251)
* fix integer precision loss due to casting to float * add test * update comment
1 parent 24f0a10 commit 4f8bc3a

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

src/operator/math_functions-inl.h

+10-5
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,24 @@ namespace op {
3535
namespace math {
3636

3737
// Wrappers for math.h unary and binary functions
38-
// - For DType != double: math::name(a) does computation in float
38+
// - For DType == float: math::name(a) does computation in float
3939
// and returns float
40-
// - For DType == double: math::name(a) does computation in double
40+
// - For DType == double or DType == integer: math::name(a) does computation in double
4141
// and returns double
4242

4343
#define MXNET_UNARY_MATH_FUNC(name) \
44-
template<typename DType> MSHADOW_XINLINE \
45-
float name(DType a) { \
46-
return ::name##f(static_cast<float>(a)); \
44+
MSHADOW_XINLINE \
45+
float name(float a) { \
46+
return ::name##f(a); \
4747
} \
4848
MSHADOW_XINLINE \
4949
double name(double a) { \
5050
return ::name(a); \
51+
} \
52+
template<typename DType> MSHADOW_XINLINE \
53+
typename std::enable_if<std::is_integral<DType>::value, double>::type \
54+
name(DType a) { \
55+
return ::name(static_cast<double>(a)); \
5156
}
5257

5358
#define MXNET_BINARY_MATH_FUNC(name) \

tests/python/unittest/test_ndarray.py

+19
Original file line numberDiff line numberDiff line change
@@ -1952,6 +1952,25 @@ def test_op(op, num_inputs, mutated_inputs, **kwargs):
19521952
{'rescale_grad': 0.1, 'lr': 0.01, 'wd': 1e-3})
19531953

19541954

1955+
def test_large_int_rounding():
1956+
large_integer = 50000001
1957+
1958+
a = mx.nd.array([large_integer], dtype='int32')
1959+
assert np.all(a == large_integer)
1960+
1961+
a = mx.nd.array([large_integer], dtype='int32').floor()
1962+
assert np.all(a == large_integer)
1963+
1964+
a = mx.nd.array([large_integer], dtype='int32').round()
1965+
assert np.all(a == large_integer)
1966+
1967+
a = mx.nd.array([large_integer], dtype='int32').ceil()
1968+
assert np.all(a == large_integer)
1969+
1970+
a = mx.nd.array([large_integer], dtype='int32').trunc()
1971+
assert np.all(a == large_integer)
1972+
1973+
19551974
if __name__ == '__main__':
19561975
import nose
19571976
nose.runmodule()

0 commit comments

Comments
 (0)