From 3a10abb65d548c8160ce3d59f662712556db058d Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Sat, 4 Jul 2020 15:10:19 +0000
Subject: [PATCH 1/7] temp

---
 tests/python/unittest/test_numpy_op.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index ed9886ea8f75..a507677c2523 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3071,7 +3071,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 
 @with_seed()
 @use_np
-@pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/16848')
+# @pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/16848')
 def test_np_mixed_precision_binary_funcs():
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
@@ -3084,6 +3084,10 @@ def __init__(self, func):
             def hybrid_forward(self, F, a, b, *args, **kwargs):
                 return getattr(F.np, self._func)(a, b)
 
+        # if (func in ['multiply', 'mod', 'equal', 'not_equal', 'greater',
+        #     'greater_equal', 'less', 'less_equal']) and \
+        #     (lshape == () or rshape == ()) :
+        #     return
         np_func = getattr(_np, func)
         mx_func = TestMixedBinary(func)
         np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)

From f69658a74813f761d0a34625416bb16aec9cc0ac Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Sat, 4 Jul 2020 15:33:15 +0000
Subject: [PATCH 2/7] change test

---
 tests/python/unittest/test_numpy_op.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index a507677c2523..81449f358a70 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3071,7 +3071,6 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 
 @with_seed()
 @use_np
-# @pytest.mark.skip(reason='https://github.com/apache/incubator-mxnet/issues/16848')
 def test_np_mixed_precision_binary_funcs():
     itypes = [np.bool, np.int8, np.int32, np.int64]
     ftypes = [np.float16, np.float32, np.float64]
@@ -3084,10 +3083,10 @@ def __init__(self, func):
             def hybrid_forward(self, F, a, b, *args, **kwargs):
                 return getattr(F.np, self._func)(a, b)
 
-        # if (func in ['multiply', 'mod', 'equal', 'not_equal', 'greater',
-        #     'greater_equal', 'less', 'less_equal']) and \
-        #     (lshape == () or rshape == ()) :
-        #     return
+        if (func in ['multiply', 'mod', 'equal', 'not_equal', 'greater',
+            'greater_equal', 'less', 'less_equal']) and \
+            (lshape == () or rshape == ()) :
+            return
         np_func = getattr(_np, func)
         mx_func = TestMixedBinary(func)
         np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
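Patches 1-2 swap the blanket skip marker for a targeted guard inside the helper. The helper itself, check_mixed_precision_binary_func, is defined elsewhere in test_numpy_op.py and never appears in these hunks; the following is only a rough sketch of the comparison it performs, with the name suffix, asserts, and tolerances being illustrative assumptions rather than the real code:

    import numpy as _np
    from mxnet import np

    def check_mixed_precision_binary_func_sketch(func, low, high, lshape, rshape,
                                                 lgrad, rgrad, ltype, rtype):
        # One operand per dtype, e.g. int32 vs. float16.
        mx_x1 = np.random.uniform(low, high, lshape).astype(ltype)
        mx_x2 = np.random.uniform(low, high, rshape).astype(rtype)
        # Reference result from official NumPy on identical data.
        np_out = getattr(_np, func)(mx_x1.asnumpy(), mx_x2.asnumpy())
        mx_out = getattr(np, func)(mx_x1, mx_x2)
        # The point of the series: mxnet.numpy must agree with NumPy
        # when the two input dtypes differ.
        assert mx_out.shape == np_out.shape
        _np.testing.assert_allclose(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)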
From dfe34422611ae85d2a61e120379771e3d5040100 Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Mon, 6 Jul 2020 06:13:56 +0000
Subject: [PATCH 3/7] fix bad func call

---
 tests/python/unittest/test_numpy_op.py | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 81449f358a70..7adbcc545d0a 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3143,16 +3143,17 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
              'mod': (1.0, 5.0, None, None),
              'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                        lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
-             'equal': (0.0, 2.0, None, None),
-             'not_equal': (0.0, 2.0, None, None),
-             'greater': (0.0, 2.0, None, None),
-             'less': (0.0, 2.0, None, None),
-             'greater_equal': (0.0, 2.0, None, None),
-             'less_equal': (0.0, 2.0, None, None),
-             'logical_and': (0.0, 2.0, None, None),
-             'logical_or': (0.0, 2.0, None, None),
-             'logical_xor': (0.0, 2.0, None, None),
             }
+    if not has_tvm_ops():
+        funcs['equal'] = (0.0, 2.0, None, None)
+        funcs['not_equal'] = (0.0, 2.0, None, None)
+        funcs['greater'] = (0.0, 2.0, None, None)
+        funcs['less'] = (0.0, 2.0, None, None)
+        funcs['greater_equal'] = (0.0, 2.0, None, None)
+        funcs['less_equal'] = (0.0, 2.0, None, None)
+        funcs['logical_and'] = (0.0, 2.0, None, None)
+        funcs['logical_or'] = (0.0, 2.0, None, None)
+        funcs['logical_xor'] = (0.0, 2.0, None, None)
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
@@ -3168,7 +3169,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 
     for func, func_data in funcs.items():
         low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
-            for type1, type2 in itertools.product(itypes, ftypes):
+            for type1, type2 in itertools.product(ftypes, ftypes):
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)

From 8cdbb02c86a13a78250e026f255515086e0f0aad Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Mon, 6 Jul 2020 07:27:21 +0000
Subject: [PATCH 4/7] test

---
 tests/python/unittest/test_numpy_op.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 7adbcc545d0a..7ee46a0d315c 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3169,7 +3169,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
 
     for func, func_data in funcs.items():
         low, high, lgrad, rgrad = func_data
         for lshape, rshape in shape_pairs:
-            for type1, type2 in itertools.product(ftypes, ftypes):
+            for type1, type2 in itertools.product(itypes, ftypes):
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type1, type2)
                 check_mixed_precision_binary_func(func, low, high, lshape, rshape, lgrad, rgrad, type2, type1)
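Patches 3-4 gate the comparison and logical ops on the build configuration: they are registered in funcs only when TVM-generated kernels are absent, because at this point only the native C++ kernels accept mixed input dtypes (patch 5 removes the gate again once the TVM path gains a fallback). A minimal sketch of the gating pattern, assuming has_tvm_ops is importable from mxnet.test_utils as in the rest of the test suite:

    from mxnet.test_utils import has_tvm_ops  # probes whether TVM kernels were built

    funcs = {'add': (-1.0, 1.0, None, None)}  # illustrative subset of the real dict
    if not has_tvm_ops():
        # Comparison ops are only safe to exercise here when they dispatch to
        # the native C++ kernels; the prebuilt TVM kernels reject mixed dtypes.
        funcs['equal'] = (0.0, 2.0, None, None)
        funcs['less'] = (0.0, 2.0, None, None)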
From cd5d6c9e91156e29c068237558f18144502d3dfc Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Thu, 23 Jul 2020 04:01:23 +0000
Subject: [PATCH 5/7] rectify

---
 .../numpy/np_elemwise_broadcast_logic_op.cc   | 42 +++++++++++++++++--
 tests/python/unittest/test_numpy_op.py        | 19 ++++-----
 2 files changed, 48 insertions(+), 13 deletions(-)

diff --git a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
index b191553f16da..9aacbc02b061 100644
--- a/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_logic_op.cc
@@ -79,7 +79,9 @@ TBlob PrependAxes(const TBlob& src, const int dst_ndim) {
   return src.reshape(dst_shape);
 }
 
-struct TVMBinaryBroadcastCompute {
+
+template<typename xpu, typename OP>
+struct GetBinaryBroadcastCompute {
   const char* func;
   void operator()(const nnvm::NodeAttrs& attrs,
                   const mxnet::OpContext& ctx,
@@ -90,6 +92,38 @@ struct TVMBinaryBroadcastCompute {
     std::vector<int> type_codes;
     std::vector<TVMValue> values;
 
+    const TBlob& a = inputs[0];
+    const TBlob& b = inputs[1];
+    if (a.type_flag_ != b.type_flag_) {
+      if (outputs[0].shape_.Size() == 0U) return;
+      mxnet::TShape new_lshape, new_rshape, new_oshape;
+      const TBlob& lhs = inputs[0];
+      const TBlob& rhs = inputs[1];
+      const TBlob& out = outputs[0];
+      int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
+                                             &new_lshape, &new_rshape, &new_oshape);
+      if (!ndim) {
+        ElemwiseBinaryOp::ComputeLogic<xpu, OP>(attrs, ctx, inputs, req, outputs);
+      } else {
+        if (req[0] == kNullOp) return;
+        mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+        MSHADOW_TYPE_SWITCH_WITH_BOOL(lhs.type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH_WITH_BOOL(rhs.type_flag_, EType, {
+            BROADCAST_NDIM_SWITCH(ndim, NDim, {
+              mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
+              mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
+              mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
+              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
+              template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
+                                lhs.dptr<DType>(), rhs.dptr<EType>(),
+                                out.dptr<bool>());
+            });
+          });
+        });
+      }
+      return;
+    }
+
     const int ondim = outputs[0].shape_.ndim();
     const size_t num_args = inputs.size() + outputs.size();
     type_codes.resize(num_args);
@@ -145,14 +179,16 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(logical_xor);
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC(logical_xor);
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_CPU(name)                   \
   NNVM_REGISTER_OP(_npi_##name)                                             \
-  .set_attr<FCompute>("FCompute<cpu>", TVMBinaryBroadcastCompute{func_##name##_cpu})
+  .set_attr<FCompute>("FCompute<cpu>",                                      \
+                      GetBinaryBroadcastCompute<cpu, mshadow_op::np_##name>{func_##name##_cpu})
 
 #if MXNET_USE_CUDA
 
 #define MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(name)                   \
   NNVM_REGISTER_OP(_npi_##name)                                             \
-  .set_attr<FCompute>("FCompute<gpu>", TVMBinaryBroadcastCompute{func_##name##_gpu})
+  .set_attr<FCompute>("FCompute<gpu>",                                      \
+                      GetBinaryBroadcastCompute<gpu, mshadow_op::np_##name>{func_##name##_gpu})
 
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(equal);
 MXNET_OPERATOR_REGISTER_NP_BINARY_LOGIC_GPU(not_equal);
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 7ee46a0d315c..81449f358a70 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3143,17 +3143,16 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
              'mod': (1.0, 5.0, None, None),
              'power': (1.0, 3.0, lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2,
                        lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)),
+             'equal': (0.0, 2.0, None, None),
+             'not_equal': (0.0, 2.0, None, None),
+             'greater': (0.0, 2.0, None, None),
+             'less': (0.0, 2.0, None, None),
+             'greater_equal': (0.0, 2.0, None, None),
+             'less_equal': (0.0, 2.0, None, None),
+             'logical_and': (0.0, 2.0, None, None),
+             'logical_or': (0.0, 2.0, None, None),
+             'logical_xor': (0.0, 2.0, None, None),
             }
-    if not has_tvm_ops():
-        funcs['equal'] = (0.0, 2.0, None, None)
-        funcs['not_equal'] = (0.0, 2.0, None, None)
-        funcs['greater'] = (0.0, 2.0, None, None)
-        funcs['less'] = (0.0, 2.0, None, None)
-        funcs['greater_equal'] = (0.0, 2.0, None, None)
-        funcs['less_equal'] = (0.0, 2.0, None, None)
-        funcs['logical_and'] = (0.0, 2.0, None, None)
-        funcs['logical_or'] = (0.0, 2.0, None, None)
-        funcs['logical_xor'] = (0.0, 2.0, None, None)
 
     shape_pairs = [((3, 2), (3, 2)),
                    ((3, 2), (3, 1)),
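The C++ change in patch 5 gives the TVM-op build a native fallback: whenever the two input dtypes differ, the compute functor dispatches to the ordinary broadcast kernel instead of the TVM kernel, broadcasting as needed and always writing a boolean output (the kernel writes through out.dptr<bool>()). In Python terms, this is the behavior it enables; the values and dtypes below are illustrative:

    from mxnet import np, npx
    npx.set_np()  # enable NumPy semantics (boolean outputs, zero-dim shapes, ...)

    a = np.array([[0, 1, 2]], dtype=np.int32)          # lhs dtype: int32
    b = np.array([[0.0, 1.5, 2.0]], dtype=np.float16)  # rhs dtype differs: float16
    out = np.equal(a, b)   # mixed dtypes previously failed on the TVM path
    print(out)             # [[ True False  True]]
    print(out.dtype)       # bool: logic kernels always emit boolean tensors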
From 1712bdd2087de3d6ccc03260648b08e7dc846f0c Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Thu, 23 Jul 2020 04:29:18 +0000
Subject: [PATCH 6/7] doc

---
 tests/python/unittest/test_numpy_op.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 81449f358a70..a4af7ee4351f 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3086,6 +3086,9 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
         if (func in ['multiply', 'mod', 'equal', 'not_equal', 'greater',
             'greater_equal', 'less', 'less_equal']) and \
             (lshape == () or rshape == ()) :
+            # the behavior of type inference for inputs of shape () differs between np and onp
+            # logical ops: when two numbers differ only in precision, NumPy also behaves unexpectedly
+            # thus, skip the tests
             return
         np_func = getattr(_np, func)
         mx_func = TestMixedBinary(func)

From 7342e0616a53990825e8df94378e2db715abf237 Mon Sep 17 00:00:00 2001
From: Yiyan66
Date: Tue, 28 Jul 2020 03:56:10 +0000
Subject: [PATCH 7/7] change test

---
 tests/python/unittest/test_numpy_op.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index a4af7ee4351f..24fe37d8ede8 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3087,9 +3087,23 @@ def hybrid_forward(self, F, a, b, *args, **kwargs):
             'greater_equal', 'less', 'less_equal']) and \
             (lshape == () or rshape == ()) :
             # the behavior of type inference for inputs of shape () differs between np and onp
+            # for example,
+            # mx_test_x1 = np.random.uniform(-2, 2, (2,3)).astype(np.float32)
+            # mx_test_x2 = np.random.uniform(-2, 2, ()).astype(np.float16)
+            # np_out = _np.mod(mx_test_x1.asnumpy(), mx_test_x2.asnumpy())  # float16
+            # mx_out = np.mod(mx_test_x1, mx_test_x2)  # float32
+
             # logical ops: when two numbers differ only in precision, NumPy also behaves unexpectedly
+            # for example,
+            # a = np.array([[1.441]], dtype=np.float16)
+            # b = np.array(1.4413278, dtype=np.float32)
+            # c = np.array([1.4413278], dtype=np.float32)
+            # np.greater(a, b), np.greater(a, c)  # True True
+            # _np.greater(a.asnumpy(), b.asnumpy()), _np.greater(a.asnumpy(), c.asnumpy())  # False True
+
             # thus, skip the tests
             return
+
         np_func = getattr(_np, func)
         mx_func = TestMixedBinary(func)
         np_test_x1 = _np.random.uniform(low, high, lshape).astype(ltype)
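The NumPy quirk documented in patch 7 is reproducible in plain NumPy under its legacy value-based promotion (changed by NEP 50 in NumPy 2.0): a 0-d operand is promoted like a scalar, so the comparison runs entirely in float16, while a 1-d operand forces an ordinary upcast to float32. Consolidating the comment's example into a runnable form:

    import numpy as _np  # official NumPy, as aliased in the test file

    a = _np.array([[1.441]], dtype=_np.float16)    # nearest float16 is 1.4414062
    b = _np.array(1.4413278, dtype=_np.float32)    # 0-d array: promoted like a scalar
    c = _np.array([1.4413278], dtype=_np.float32)  # 1-d array: ordinary promotion

    # 0-d rhs: the comparison happens in float16, where both values round
    # to the same number, so "greater" is False.
    print(_np.greater(a, b))   # [[False]]
    # 1-d rhs: `a` is upcast to float32, where 1.4414062 > 1.4413278.
    print(_np.greater(a, c))   # [[ True]]

mxnet.numpy, by contrast, promotes to the wider float in both cases and returns True twice (per the author's comment), which is why the test skips scalar-shaped operands rather than asserting agreement with NumPy there.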