diff --git a/3rdparty/mshadow/mshadow/base.h b/3rdparty/mshadow/mshadow/base.h
index ec798337e7d0..5f477e52801d 100755
--- a/3rdparty/mshadow/mshadow/base.h
+++ b/3rdparty/mshadow/mshadow/base.h
@@ -606,6 +606,64 @@ struct divto {
   typedef op::div OPType;
 };
 }  // namespace sv
+
+#ifndef __CUDA_ARCH__
+using std::isnan;
+using std::isinf;
+#endif
+
+/*! \brief
+ *  determines if the given floating point
+ *  number is not a number */
+namespace isnan_typed {
+  template<typename DType>
+  MSHADOW_XINLINE bool IsNan(volatile DType val) {
+    return false;
+  }
+  template<>
+  MSHADOW_XINLINE bool IsNan(volatile float val) {
+    return isnan(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsNan(volatile double val) {
+    return isnan(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsNan(volatile long double val) {
+    return isnan(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
+    return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) > MSHADOW_HALF_EXPONENT_BITS;
+  }
+}  // namespace isnan_typed
+
+/*! \brief
+ *  determines if the given floating point
+ *  number is a positive or negative infinity */
+namespace isinf_typed {
+  template<typename DType>
+  MSHADOW_XINLINE bool IsInf(volatile DType val) {
+    return false;
+  }
+  template<>
+  MSHADOW_XINLINE bool IsInf(volatile float val) {
+    return isinf(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsInf(volatile double val) {
+    return isinf(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsInf(volatile long double val) {
+    return isinf(val);
+  }
+  template<>
+  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
+    return (val.half_ & (~MSHADOW_HALF_SIGN_BIT)) == MSHADOW_HALF_EXPONENT_BITS;
+  }
+}  // namespace isinf_typed
+
 /*! \brief namespace for potential reducer operations */
 namespace red {
 namespace limits {
@@ -674,6 +732,12 @@ template<>
 MSHADOW_XINLINE double NegInfValue<double>(void) {
   return -HUGE_VAL;
 }
+/*! \brief negative infinity value of float16 */
+template<>
+MSHADOW_XINLINE half::half_t NegInfValue<half::half_t>(void) {
+  return half::half_t::Binary(
+      MSHADOW_HALF_SIGN_BIT | MSHADOW_HALF_EXPONENT_BITS);
+}
 
 /*!
  * \brief maximum value of certain types
@@ -740,6 +804,11 @@ template<>
 MSHADOW_XINLINE double PosInfValue<double>(void) {
   return HUGE_VAL;
 }
+/*! \brief positive infinity value of float16 */
+template<>
+MSHADOW_XINLINE half::half_t PosInfValue<half::half_t>(void) {
+  return half::half_t::Binary(MSHADOW_HALF_EXPONENT_BITS);
+}
 }  // namespace limits
 
@@ -755,7 +824,11 @@ struct sum {
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
     DType y = src - residual;
     DType t = dst + y;
-    residual = (t - dst) - y;
+    if (isinf_typed::IsInf(t)) {
+      residual = 0;
+    } else {
+      residual = (t - dst) - y;
+    }
     dst = t;
   }
   /*! \brief combine the results of two reducers */
@@ -767,10 +840,15 @@ struct sum {
   template<typename DType>
   MSHADOW_XINLINE static void Merge(volatile DType& dst_val, volatile DType& dst_residual, volatile DType& src_val, volatile DType& src_residual) { // NOLINT(*)
     DType t1 = dst_val + src_val;
-    DType e = t1 - dst_val;
-    DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
-    dst_val = t1 + t2;
-    dst_residual = t2 - (dst_val - t1);
+    if (isinf_typed::IsInf(t1)) {
+      dst_val = t1;
+      dst_residual = 0;
+    } else {
+      DType e = t1 - dst_val;
+      DType t2 = ((src_val - e) + (dst_val - (t1 - e))) + dst_residual + src_residual;
+      dst_val = t1 + t2;
+      dst_residual = t2 - (dst_val - t1);
+    }
   }
   /*! \brief finalize reduction */
   template<typename DType>
@@ -807,12 +885,9 @@ struct maximum {
   /*! \brief do reduction into dst */
   template<typename DType>
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
-    using namespace std;
-#ifdef __CUDACC__
-    dst = ::max(dst, src);
-#else
-    dst = max(dst, src);
-#endif  // __CUDACC__
+    if (!isnan_typed::IsNan(dst)) {
+      if (!(dst >= src)) dst = src;
+    }
   }
   /*! \brief do reduction into dst */
   template<typename DType>
@@ -863,12 +938,9 @@ struct minimum {
   /*! \brief do reduction into dst */
   template<typename DType>
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
-    using namespace std;
-#ifdef __CUDACC__
-    dst = ::min(dst, src);
-#else
-    dst = min(dst, src);
-#endif  // __CUDACC__
+    if (!isnan_typed::IsNan(dst)) {
+      if (!(dst <= src)) dst = src;
+    }
   }
   /*! \brief do reduction into dst */
   template<typename DType>
diff --git a/3rdparty/mshadow/mshadow/extension/reduce_with_axis.h b/3rdparty/mshadow/mshadow/extension/reduce_with_axis.h
index 54bcc750cfc5..26b6156ad6f9 100644
--- a/3rdparty/mshadow/mshadow/extension/reduce_with_axis.h
+++ b/3rdparty/mshadow/mshadow/extension/reduce_with_axis.h
@@ -112,7 +112,7 @@ struct Plan<ReduceWithAxisExp<Reducer, SrcExp, DType, dimsrc, mask, dimdst>, DTy
       index_t z = (x*size_+k)*trailing_+y;
       DType tmp = res;
       Reducer::Reduce(res, src_.Eval(z/last_, z%last_));
-      if (tmp != res) {
+      if (tmp != res && !isnan_typed::IsNan(tmp)) {
         idx = k;
       }
     }
diff --git a/3rdparty/mshadow/mshadow/half.h b/3rdparty/mshadow/mshadow/half.h
index 2dded0a7752e..1cc53ae0460f 100644
--- a/3rdparty/mshadow/mshadow/half.h
+++ b/3rdparty/mshadow/mshadow/half.h
@@ -349,6 +349,8 @@ MSHADOW_HALF_OPERATOR(bool, <=)
 
 #define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF);
 #define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF);
+#define MSHADOW_HALF_SIGN_BIT 0x8000
+#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
 }  // namespace half
 }  // namespace mshadow
 #endif  // MSHADOW_HALF_H_
diff --git a/julia/src/ndarray/reduction.jl b/julia/src/ndarray/reduction.jl
index 833b483ca321..2045ce231674 100644
--- a/julia/src/ndarray/reduction.jl
+++ b/julia/src/ndarray/reduction.jl
@@ -47,8 +47,7 @@ broadcasted(::typeof(min), x::NDArray{T}, y::NDArray{T}) where {T} =
 """
     argmax(x::NDArray; dims) -> indices
 
-Note that `NaN` is skipped during comparison.
-This is different from Julia `Base.argmax`.
+Note that `NaN` is treated as greater than all other values in `argmax`.
 
 ## Examples
 
@@ -77,8 +76,7 @@ Base.argmax(x::NDArray; dims = :) = _argmax(x, dims) .+ 1
 """
    argmin(x::NDArray; dims) -> indices
 
-Note that `NaN` is skipped during comparison.
-This is different from Julia `Base.argmin`.
+Note that `NaN` is treated as less than all other values in `argmin`.
 
 ## Examples
 
diff --git a/julia/test/unittest/ndarray.jl b/julia/test/unittest/ndarray.jl
index 638963f1b8aa..599b0a65bfc4 100644
--- a/julia/test/unittest/ndarray.jl
+++ b/julia/test/unittest/ndarray.jl
@@ -1515,8 +1515,8 @@ function test_argmax()
          4 2 6]
     x = NDArray(A)
 
-    @test copy(argmax(x, dims = 1)) == [2 1 2]
-    @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
+    @test copy(argmax(x, dims = 1)) == [x[1] for x ∈ argmax(A, dims = 1)]
+    @test copy(argmax(x, dims = 2)) == [x[2] for x ∈ argmax(A, dims = 2)]
   end
 
   @info "NDArray::argmax::NaN"
@@ -1525,8 +1525,8 @@ function test_argmax()
          NaN 2 6]
     x = NDArray(A)
 
-    @test copy(argmax(x, dims = 1)) == [1 1 2]
-    @test copy(argmax(x, dims = 2)) == reshape([2, 3], :, 1)
+    @test copy(argmax(x, dims = 1)) == [x[1] for x ∈ argmax(A, dims = 1)]
+    @test copy(argmax(x, dims = 2)) == [x[2] for x ∈ argmax(A, dims = 2)]
   end
 end
 
@@ -1537,8 +1537,8 @@ function test_argmin()
          4 2 6]
     x = NDArray(A)
 
-    @test copy(argmin(x, dims = 1)) == [1 2 1]
-    @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
+    @test copy(argmin(x, dims = 1)) == [x[1] for x ∈ argmin(A, dims = 1)]
+    @test copy(argmin(x, dims = 2)) == [x[2] for x ∈ argmin(A, dims = 2)]
   end
 
   @info "NDArray::argmin::NaN"
@@ -1547,8 +1547,8 @@ function test_argmin()
          NaN 2 6]
     x = NDArray(A)
 
-    @test copy(argmin(x, dims = 1)) == [1 2 1]
-    @test copy(argmin(x, dims = 2)) == reshape([1, 2], :, 1)
+    @test copy(argmin(x, dims = 1)) == [x[1] for x ∈ argmin(A, dims = 1)]
+    @test copy(argmin(x, dims = 2)) == [x[2] for x ∈ argmin(A, dims = 2)]
   end
 end
 
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index f656daea3016..7d8cc524c817 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -4935,6 +4935,7 @@ class DLDataType(ctypes.Structure):
         "bool": (1, 1, 1),
         "uint32": (1, 32, 1),
         "uint64": (1, 64, 1),
+        'float16': (2, 16, 1),
         "float32": (2, 32, 1),
         "float64": (2, 64, 1),
     }
diff --git a/src/operator/contrib/allclose_op-inl.h b/src/operator/contrib/allclose_op-inl.h
index a10c7795e568..c54b2630924c 100644
--- a/src/operator/contrib/allclose_op-inl.h
+++ b/src/operator/contrib/allclose_op-inl.h
@@ -84,7 +84,7 @@ inline bool AllCloseType(const nnvm::NodeAttrs& attrs,
   return (*out_attrs)[0] != -1;
 }
 
-using namespace mshadow_op::isnan_typed;
+using mshadow::isnan_typed::IsNan;
 
 template<int req>
 struct allclose_forward {
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index f9d4e23a3a6f..4ae587188d1b 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -27,6 +27,7 @@
 #define MXNET_OPERATOR_MSHADOW_OP_H_
 
 #include <mxnet/base.h>
+#include <mshadow/base.h>
 #include "math.h"
 #include "math_functions-inl.h"
 #include "special_functions-inl.h"
@@ -41,6 +42,9 @@ namespace mxnet {
 namespace op {
 namespace mshadow_op {
 
+using mshadow::isnan_typed::IsNan;
+using mshadow::isinf_typed::IsInf;
+
 #ifdef __CUDA_ARCH__
 __constant__ const float PI = 3.14159265358979323846;
 __constant__ const float SELU_ALPHA = 1.6732632423543772848170429916717;
 __constant__ const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 __constant__ const float SQRT_2 = 1.4142135623730950488016887242096;
@@ -51,8 +55,6 @@ const float PI = 3.14159265358979323846;
 const float SELU_ALPHA = 1.6732632423543772848170429916717;
 const float SELU_LAMBDA = 1.0507009873554804934193349852946;
 const float SQRT_2 = 1.4142135623730950488016887242096;
-using std::isnan;
-using std::isinf;
 #endif
 using std::enable_if;
 using std::is_unsigned;
@@ -1001,61 +1003,13 @@ struct product {
   }
 };
 
-namespace isnan_typed {
-  template<typename DType>
-  MSHADOW_XINLINE bool IsNan(volatile DType val) {
-    return false;
-  }
-  template<>
-  MSHADOW_XINLINE bool IsNan(volatile float val) {
-    return isnan(val);
-  }
-  template<>
-  MSHADOW_XINLINE bool IsNan(volatile double val) {
-    return isnan(val);
-  }
-  template<>
-  MSHADOW_XINLINE bool IsNan(volatile long double val) {
-    return isnan(val);
-  }
-
-  template<>
-  MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
-    return (val.half_ & 0x7fff) > 0x7c00;
-  }
-};  // namespace isnan_typed
-
-namespace isinf_typed {
-  template<typename DType>
-  MSHADOW_XINLINE bool IsInf(volatile DType val) {
-    return false;
-  }
-  template<>
-  MSHADOW_XINLINE bool IsInf(volatile float val) {
-    return isinf(val);
-  }
-  template<>
-  MSHADOW_XINLINE bool IsInf(volatile double val) {
-    return isinf(val);
-  }
-  template<>
-  MSHADOW_XINLINE bool IsInf(volatile long double val) {
-    return isinf(val);
-  }
-
-  template<>
-  MSHADOW_XINLINE bool IsInf(volatile mshadow::half::half_t val) {
-    return (val.half_ & 0x7fff) >= 0x7c00;
-  }
-};  // namespace isinf_typed
-
-MXNET_UNARY_MATH_OP_NC(relu, isnan_typed::IsNan(a) || (a > DType(0)) ? a : DType(0));
+MXNET_UNARY_MATH_OP_NC(relu, IsNan(a) || (a > DType(0)) ? a : DType(0));
 
 /*! \brief used for computing gradient of relu operator */
 struct relu_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a) {
-    if (isnan_typed::IsNan(a)) {
+    if (IsNan(a)) {
       return a;
     } else {
       return a > DType(0) ? DType(1) : DType(0);
@@ -1067,7 +1021,7 @@ struct relu_grad : public mxnet_op::tunable {
 struct maximum : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    if (isnan_typed::IsNan(a)) {
+    if (IsNan(a)) {
       return a;
     } else {
       return (a > b ? a : b);
@@ -1079,7 +1033,7 @@ struct maximum : public mxnet_op::tunable {
 struct minimum : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    if (isnan_typed::IsNan(a)) {
+    if (IsNan(a)) {
       return a;
     } else {
       return DType(a < b ? a : b);
@@ -1092,13 +1046,13 @@ struct nansum {
   /*! \brief do reduction into dst */
   template<typename DType>
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
-    if (isnan_typed::IsNan(src)) return;
+    if (IsNan(src)) return;
     dst += src;
   }
   /*! \brief do reduction into dst */
   template<typename DType>
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src, volatile DType& residual) { // NOLINT(*)
-    if (isnan_typed::IsNan(src)) return;
+    if (IsNan(src)) return;
     DType y = src - residual;
     DType t = dst + y;
     residual = (t - dst) - y;
@@ -1144,7 +1098,7 @@ struct nansum {
 struct nansum_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return isnan_typed::IsNan(a) ? DType(0) : DType(1);
+    return IsNan(a) ? DType(0) : DType(1);
   }
 };
 
@@ -1153,7 +1107,7 @@ struct nanprod {
   /*! \brief do reduction into dst */
   template<typename DType>
   MSHADOW_XINLINE static void Reduce(volatile DType& dst, volatile DType src) { // NOLINT(*)
-    if (isnan_typed::IsNan(src)) return;
+    if (IsNan(src)) return;
     dst *= src;
   }
   /*! \brief do reduction into dst */
@@ -1327,7 +1281,7 @@ struct sum {
 struct nanprod_grad : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static DType Map(DType a, DType b) {
-    return isnan_typed::IsNan(a) ? DType(0) : b / a;
+    return IsNan(a) ? DType(0) : b / a;
   }
 };
 
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 27013dfb98ae..577c994a8ee1 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -699,9 +699,9 @@ struct nan_to_num_forward {
                                   const DType posinf,
                                   const DType neginf) {
     DType val = in_data[i];
-    if (mshadow_op::isnan_typed::IsNan(val)) val = nan;
-    if (val > 0 && mshadow_op::isinf_typed::IsInf(val)) val = posinf;
-    if (val < 0 && mshadow_op::isinf_typed::IsInf(val)) val = neginf;
+    if (mshadow_op::IsNan(val)) val = nan;
+    if (val > 0 && mshadow_op::IsInf(val)) val = posinf;
+    if (val < 0 && mshadow_op::IsInf(val)) val = neginf;
     KERNEL_ASSIGN(out_data[i], req, val);
   }
 };
@@ -758,9 +758,9 @@ struct nan_to_num_backward {
                                   const DType* out_grad,
                                   const DType* in_data) {
     DType val = out_grad[i];
-    if (mshadow_op::isnan_typed::IsNan(in_data[i])) val = 0;
-    if (val > 0 && mshadow_op::isinf_typed::IsInf(in_data[i])) val = 0;
-    if (val < 0 && mshadow_op::isinf_typed::IsInf(in_data[i])) val = 0;
+    if (mshadow_op::IsNan(in_data[i])) val = 0;
+    if (val > 0 && mshadow_op::IsInf(in_data[i])) val = 0;
+    if (val < 0 && mshadow_op::IsInf(in_data[i])) val = 0;
     KERNEL_ASSIGN(in_grad[i], req, val);
   }
 };
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 9b3ce3192b9e..4c6d9f7b8ef2 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -21,6 +21,7 @@
 from itertools import permutations, combinations_with_replacement
 import os
 import pickle as pkl
+import random
 import functools
 from nose.tools import assert_raises, raises
 from common import with_seed, assertRaises, TemporaryDirectory
@@ -31,7 +32,7 @@
 from mxnet.test_utils import same
 from mxnet.test_utils import random_sample, rand_shape_nd, random_arrays
 from mxnet import runtime
-from numpy.testing import assert_allclose
+from numpy.testing import assert_allclose, assert_array_equal, assert_array_almost_equal
 import mxnet.autograd
 from mxnet.base import integer_types
 from mxnet.ndarray.ndarray import py_slice
@@ -580,13 +581,40 @@ def test_dot():
 
 @with_seed()
 def test_reduce():
-    sample_num = 200
-    def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
+    sample_num = 300
+    def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes,
+                          allow_almost_equal=False, check_dtype=True):
+        dtypes = [(np.float16, 1),
+                  (np.float32, 5),
+                  (np.double, 6)]
         for i in range(sample_num):
+            dtype, decimal = random.choice(dtypes)
             ndim = np.random.randint(1, 6)
             shape = np.random.randint(1, 11, size=ndim)
-            dat = np.random.rand(*shape) - 0.5
+            dat = (np.random.rand(*shape) - 0.5).astype(dtype)
             keepdims = np.random.randint(0, 2)
+
+            allow_nan = np.random.randint(0, 2)
+            if allow_nan:
+                total_nans = np.random.randint(0, dat.size//10+1)
+                dat.ravel()[np.random.choice(
+                    dat.size, total_nans, replace=False)] = np.nan
+
+            allow_inf = np.random.randint(0, 2)
+            if allow_inf:
+                r = np.random.randint(0, 3)
+                total_infs = np.random.randint(0, dat.size//20+1)
+                if r == 0:
+                    total_pos_infs, total_neg_infs = total_infs, 0
+                elif r == 1:
+                    total_pos_infs, total_neg_infs = 0, total_infs
+                else:
+                    total_pos_infs = total_neg_infs = total_infs // 2
+                dat.ravel()[np.random.choice(
+                    dat.size, total_pos_infs, replace=False)] = np.inf
+                dat.ravel()[np.random.choice(
+                    dat.size, total_neg_infs, replace=False)] = -np.inf
+
             if multi_axes:
                 axis_flags = np.random.randint(0, 2, size=ndim)
                 axes = []
@@ -601,16 +629,22 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
                 axes = np.random.randint(0, ndim)
             numpy_ret = numpy_reduce_func(dat, axis=axes, keepdims=keepdims)
 
-            ndarray_ret = nd_reduce_func(mx.nd.array(dat), axis=axes, keepdims=keepdims)
+            mx_arr = mx.nd.array(dat, dtype=dtype)
+            ndarray_ret = nd_reduce_func(mx_arr, axis=axes, keepdims=keepdims)
             if type(ndarray_ret) is mx.ndarray.NDArray:
                 ndarray_ret = ndarray_ret.asnumpy()
             assert (ndarray_ret.shape == numpy_ret.shape) or \
                    (ndarray_ret.shape == (1,) and numpy_ret.shape == ()), "nd:%s, numpy:%s" \
                                                                           %(ndarray_ret.shape, numpy_ret.shape)
-            err = np.square(ndarray_ret - numpy_ret).mean()
-            assert err < 1E-4
+            if check_dtype:
+                assert ndarray_ret.dtype == numpy_ret.dtype,\
+                    (ndarray_ret.dtype, numpy_ret.dtype)
+            if allow_almost_equal:
+                assert_array_almost_equal(ndarray_ret, numpy_ret, decimal=decimal)
+            else:
+                assert_array_equal(ndarray_ret, numpy_ret)
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.sum),
-                      mx.nd.sum, True)
+                      mx.nd.sum, True, allow_almost_equal=True)
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.max),
                       mx.nd.max, True)
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(data, axis, keepdims, np.min),
@@ -619,10 +653,10 @@ def test_reduce_inner(numpy_reduce_func, nd_reduce_func, multi_axes):
     # Force numpy to match mxnet's float32.
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(np.float32(data), axis, keepdims,
                                                             np.argmax),
-                      mx.nd.argmax, False)
+                      mx.nd.argmax, False, check_dtype=False)
     test_reduce_inner(lambda data, axis, keepdims:np_reduce(np.float32(data), axis, keepdims,
                                                             np.argmin),
-                      mx.nd.argmin, False)
+                      mx.nd.argmin, False, check_dtype=False)
 
 
 @with_seed()
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 376e177d0659..c5fb310e0c51 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9295,23 +9295,6 @@ def test_sample_normal_default_shape():
         assert s.shape == (1, 1)
 
 
-def test_min_max_inf():
-    dtypes = [np.float32, np.double]
-    elem_list = [-1, 1, 0, np.inf, -np.inf]
-
-    for dtype in dtypes:
-        for a in elem_list:
-            for b in elem_list:
-                data_np = np.array([a, b], dtype=dtype)
-                data_mx = mx.nd.array(data_np, dtype=dtype)
-
-                min_data_np, max_data_np = data_np.min(), data_np.max()
-                min_data_mx, max_data_mx = data_mx.min(), data_mx.max()
-
-                assert_array_equal(min_data_np, min_data_mx.asnumpy())
-                assert_array_equal(max_data_np, max_data_mx.asnumpy())
-
-
 def test_large_tensor_disabled_err_msg():
     LARGE_X = 4300000000
     MEDIUM_X = 1000000000
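Reviewer note, not part of the patch: a minimal sketch of the behavior this change is meant to enforce, assuming a build with the patch applied (values shown in the comments are expectations, not captured output).

```python
import numpy as np
import mxnet as mx

# NaN now propagates through min/max instead of being skipped, matching np.max/np.min,
# and argmax/argmin treat NaN as the extreme value, matching np.argmax/np.argmin.
a = mx.nd.array([1.0, np.nan, 3.0])
print(a.max().asnumpy())                  # expected: [nan]
print(mx.nd.argmax(a, axis=0).asnumpy())  # expected: [1.]

# +/-inf survives the stable (Kahan) sum instead of collapsing to NaN,
# because the residual is zeroed once the running total becomes infinite.
b = mx.nd.array([1.0, np.inf, 2.0])
print(b.sum().asnumpy())                  # expected: [inf]
```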