From afc6bf366cd5ded94d4d2a273bbda574fb23677e Mon Sep 17 00:00:00 2001
From: Fan
Date: Tue, 2 Jul 2019 17:37:16 +0800
Subject: [PATCH 1/3] add numpy compatible copysign

---
 python/mxnet/ndarray/numpy/_op.py             | 45 +++++++++++++-
 python/mxnet/numpy/multiarray.py              | 45 +++++++++++++-
 python/mxnet/symbol/numpy/_symbol.py          | 28 ++++++++-
 src/operator/mshadow_op.h                     | 10 ++++
 .../numpy/np_elemwise_broadcast_op.cc         | 40 +++++++++++++
 .../numpy/np_elemwise_broadcast_op.cu         | 21 +++++++
 src/operator/operator_tune.cc                 |  5 ++
 tests/python/unittest/test_numpy_op.py        | 60 +++++++++++++++++++
 8 files changed, 251 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 671345c9a546..602b8fc31210 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,7 @@
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc',
            'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims',
            'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 
 @set_module('mxnet.ndarray.numpy')
@@ -2432,3 +2432,46 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
     else:
         raise ValueError("The dimensions must be sequence of ints")
 # pylint: enable=redefined-outer-name
+
+
+@set_module('mxnet.ndarray.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Values to change the sign of.
+    x2 : ndarray or scalar
+        The sign of `x2` is copied to `x1`.
+    out : ndarray or None, optional
+        A location into which the result is stored. It must be of the
+        right shape and right type to hold the output. If not provided
+        or `None`, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Examples
+    --------
+    >>> np.copysign(1.3, -1)
+    -1.3
+    >>> 1/np.copysign(0, 1)
+    inf
+    >>> 1/np.copysign(0, -1)
+    -inf
+
+    >>> a = np.array([-1, 0, 1])
+    >>> np.copysign(a, -1.1)
+    array([-1., -0., -1.])
+    >>> np.copysign(a, np.arange(3)-1)
+    array([-1.,  0.,  1.])
+    """
+    return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
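The `_ufunc_helper` call above routes `copysign` to one of four code paths depending on whether `x1` and `x2` are ndarrays or scalars: `_npi.copysign` (array/array), `_npi.copysign_scalar` (array/scalar), `_npi.rcopysign_scalar` (scalar/array), and plain `_np.copysign` (scalar/scalar). A quick plain-NumPy sketch of the behaviour each path is expected to reproduce (illustration only, not part of the patch; `onp` is an alias used here just to avoid clashing with `mxnet.numpy`):

    import numpy as onp

    a = onp.array([-1.5, 0.0, 2.0])
    b = onp.array([1.0, -2.0, -3.0])

    print(onp.copysign(a, b))     # array/array   -> [ 1.5 -0.  -2. ]
    print(onp.copysign(a, -1))    # array/scalar  -> [-1.5 -0.  -2. ]
    print(onp.copysign(-1, a))    # scalar/array  -> [-1.  1.  1.]
    print(onp.copysign(1.3, -1))  # scalar/scalar -> -1.3
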
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index 1f8aa92f9851..2ad02cde535b 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -52,7 +52,7 @@
            'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
            'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
            'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 # Return code for dispatching indexing function call
 _NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3935,3 +3935,46 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
     """
     return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
 # pylint: enable=redefined-outer-name
+
+
+@set_module('mxnet.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : ndarray or scalar
+        Values to change the sign of.
+    x2 : ndarray or scalar
+        The sign of `x2` is copied to `x1`.
+    out : ndarray or None, optional
+        A location into which the result is stored. It must be of the
+        right shape and right type to hold the output. If not provided
+        or `None`, a freshly-allocated array is returned.
+
+    Returns
+    -------
+    out : ndarray or scalar
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    Examples
+    --------
+    >>> np.copysign(1.3, -1)
+    -1.3
+    >>> 1/np.copysign(0, 1)
+    inf
+    >>> 1/np.copysign(0, -1)
+    -inf
+
+    >>> a = np.array([-1, 0, 1])
+    >>> np.copysign(a, -1.1)
+    array([-1., -0., -1.])
+    >>> np.copysign(a, np.arange(3)-1)
+    array([-1.,  0.,  1.])
+    """
+    return _mx_nd_np.copysign(x1, x2, out=out)
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index 077008aba119..e34887d154c5 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -35,7 +35,7 @@
            'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc',
            'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims',
            'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
+           'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']
 
 
 def _num_outputs(sym):
@@ -2744,4 +2744,30 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
 # pylint: enable=redefined-outer-name
 
 
+@set_module('mxnet.symbol.numpy')
+def copysign(x1, x2, out=None):
+    r"""copysign(x1, x2, out=None)
+
+    Change the sign of x1 to that of x2, element-wise.
+
+    If `x2` is a scalar, its sign will be copied to all elements of `x1`.
+
+    Parameters
+    ----------
+    x1 : _Symbol or scalar
+        Values to change the sign of.
+    x2 : _Symbol or scalar
+        The sign of `x2` is copied to `x1`.
+    out : _Symbol or None
+        Dummy parameter to keep the consistency with the ndarray counterpart.
+
+    Returns
+    -------
+    out : _Symbol
+        The values of `x1` with the sign of `x2`.
+        This is a scalar if both `x1` and `x2` are scalars.
+    """
+    return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
+
+
 _set_np_symbol_class(_Symbol)
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 616192e4af57..f3d24b2c119e 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -417,6 +417,16 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));
 
 MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));
 
+MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);
+
+MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);
+
+MXNET_BINARY_MATH_OP(copysign_rgrad, 0);
+
+MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);
+
+MXNET_BINARY_MATH_OP(rcopysign_grad, 0);
+
 struct mod : public mxnet_op::tunable {
   template<typename DType>
   MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc
index 697657d84dd5..54db4dd39217 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cc
@@ -76,6 +76,26 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
 .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<cpu, mshadow_op::power>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});
 
+MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr<FCompute>("FCompute", BinaryBroadcastCompute<cpu, mshadow_op::copysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});
+
+NNVM_REGISTER_OP(_backward_npi_copysign)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 1}};
+  })
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+  })
+.set_attr<FCompute>("FCompute", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
+                                                              mshadow_op::copysign_rgrad>);
+
 MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
@@ -108,5 +108,25 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});
 
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});
+
+MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<cpu, mshadow_op::rcopysign>)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_copysign_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute",
+                    BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);
+
+MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rcopysign_scalar)
+.add_argument("scalar", "float", "scalar value")
+.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
+.set_attr<FCompute>("FCompute", BinaryScalarOp::Backward<
+  cpu, mshadow_op::rcopysign_grad>);
+
 }  // namespace op
 }  // namespace mxnet
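The kernels registered above encode a particular gradient convention for copysign: `mshadow_op::copysign_grad` returns +1 when `x1` and `x2` have the same sign and -1 otherwise (the derivative of |x1|*sign(x2) with respect to x1), while `mshadow_op::copysign_rgrad` and `mshadow_op::rcopysign_grad` are 0, because the output does not depend on the magnitude of the sign-providing argument. A small finite-difference check with plain NumPy (illustration only, not part of the patch):

    import numpy as onp

    def copysign_grads(x1, x2, eps=1e-6):
        # central finite differences of y = copysign(x1, x2)
        dy_dx1 = (onp.copysign(x1 + eps, x2) - onp.copysign(x1 - eps, x2)) / (2 * eps)
        dy_dx2 = (onp.copysign(x1, x2 + eps) - onp.copysign(x1, x2 - eps)) / (2 * eps)
        return dy_dx1, dy_dx2

    print(copysign_grads(1.5, 2.0))    # ~(+1.0, 0.0): signs agree
    print(copysign_grads(1.5, -2.0))   # ~(-1.0, 0.0): signs differ
    print(copysign_grads(-1.5, -2.0))  # ~(+1.0, 0.0): signs agree
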
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu
index ac8def2af2c2..5dd65cc859a9 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op.cu
@@ -42,6 +42,13 @@ NNVM_REGISTER_OP(_npi_mod)
 NNVM_REGISTER_OP(_npi_power)
 .set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, mshadow_op::power>);
 
+NNVM_REGISTER_OP(_npi_copysign)
+.set_attr<FCompute>("FCompute", BinaryBroadcastCompute<gpu, mshadow_op::copysign>);
+
+NNVM_REGISTER_OP(_backward_npi_copysign)
+.set_attr<FCompute>("FCompute", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
+                                                             mshadow_op::copysign_rgrad>);
+
 NNVM_REGISTER_OP(_npi_add_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
 
@@ -66,5 +66,19 @@ NNVM_REGISTER_OP(_npi_power_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, mshadow_op::power>);
 
 NNVM_REGISTER_OP(_npi_rpower_scalar)
 .set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);
 
+NNVM_REGISTER_OP(_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, mshadow_op::copysign>);
+
+NNVM_REGISTER_OP(_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute", BinaryScalarOp::Compute<gpu, mshadow_op::rcopysign>);
+
+NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
+.set_attr<FCompute>("FCompute",
+                    BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);
+
+NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
+.set_attr<FCompute>("FCompute",
+                    BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);
+
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 98ce14e7bf05..51595254cafc 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -328,6 +328,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad);  // NOLINT()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index c5b0907fb7a8..f544a65e3b3f 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1853,6 +1853,66 @@ def hybrid_forward(self, F, x):
         assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)
 
 
+@with_seed()
+@use_np
+def test_np_copysign():
+    class TestCopysign(HybridBlock):
+        def __init__(self):
+            super(TestCopysign, self).__init__()
+
+        def hybrid_forward(self, F, a1, a2):
+            return F.np.copysign(a1, a2)
+
+    def get_grad(a1, a2):
+        sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
+                              _np.logical_and(a1 >= 0, a2 >= 0))
+        sign = 2 * sign.astype(int) - 1
+        sign = sign.reshape(-1, *a1.shape)
+        sign = _np.sum(sign, axis=0)
+        return sign, _np.zeros_like(a2)
+
+    shapes = [
+        (),
+        (1),
+        (2, 1),
+        (3, 2, 1),
+        (4, 3, 2, 1),
+        (2, 4, 3, 2, 1)
+    ]
+    types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
+    for a1shape in shapes:
+        for a2shape in shapes:
+            for hybridize in [True, False]:
+                for dtype in types:
+                    test_copysign = TestCopysign()
+                    if hybridize:
+                        test_copysign.hybridize()
+                    rtol = 1e-3
+                    atol = 1e-5
+                    a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype)
+                    a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype)
+                    a1 = mx.nd.array(a1_np).as_np_ndarray()
+                    a2 = mx.nd.array(a2_np).as_np_ndarray()
+                    a1.attach_grad()
+                    a2.attach_grad()
+                    expected_np = _np.copysign(a1_np, a2_np)
+                    with mx.autograd.record():
+                        mx_out = test_copysign(a1, a2)
+                    assert mx_out.shape == expected_np.shape
+                    assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+                    # Test gradient
+                    mx_out.backward()
+                    a1_grad, a2_grad = get_grad(a1_np, a2_np)
+                    assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
+                    assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)
+
+                    # Test imperative once again
+                    mx_out = np.copysign(a1, a2)
+                    expected_np = _np.copysign(a1_np, a2_np)
+                    assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
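One detail of `test_np_copysign` above that is easy to miss: when `a1` broadcasts against a larger `a2`, each element of `a1` contributes to several output elements, so the expected gradient is the per-element ±1 sign summed over the broadcast axes. `get_grad` does this with `sign.reshape(-1, *a1.shape)` followed by a sum over axis 0, which works because the test only pairs shapes where `a1.shape` is a trailing suffix of the broadcast shape. A plain-NumPy sketch of the same reduction (illustration only, not part of the patch):

    import numpy as onp

    a1 = onp.array([-1.0, 2.0])             # shape (2,)
    a2 = onp.array([[1.0, -1.0],
                    [-3.0, 4.0],
                    [5.0, -6.0]])           # shape (3, 2); broadcast output is (3, 2)

    # per-element d copysign(a1, a2) / d a1: +1 where signs agree, -1 otherwise
    sign = onp.where((a1 >= 0) == (a2 >= 0), 1.0, -1.0)
    a1_grad = sign.reshape(-1, *a1.shape).sum(axis=0)   # collapse the broadcast axis
    print(a1_grad)                          # [-1. -1.]
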
""" return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index 54db4dd39217..a9254e8a7d02 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -94,7 +94,7 @@ NNVM_REGISTER_OP(_backward_npi_copysign) return std::vector{ResourceRequest::kTempSpace}; }) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); + mshadow_op::copysign_rgrad>); MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) @@ -136,17 +136,13 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"}); -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_copysign_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); -MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rcopysign_scalar) -.add_argument("scalar", "float", "scalar value") -.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) -.set_attr("FCompute", BinaryScalarOp::Backward< - cpu, mshadow_op::rcopysign_grad>); +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) +.set_attr("FCompute", + BinaryScalarOp::Backward); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index 5dd65cc859a9..ecf8e8531334 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -80,7 +80,7 @@ NNVM_REGISTER_OP(_npi_rcopysign_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); NNVM_REGISTER_OP(_backward_npi_copysign_scalar) -.set_attr("FCompute", +.set_attr("FCompute", BinaryScalarOp::Backward); NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index f544a65e3b3f..6fa2199bb4f0 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1891,8 +1891,8 @@ def get_grad(a1, a2): atol = 1e-5 a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype) a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype) - a1 = mx.nd.array(a1_np).as_np_ndarray() - a2 = mx.nd.array(a2_np).as_np_ndarray() + a1 = np.array(a1_np, dtype=dtype) + a2 = np.array(a2_np, dtype=dtype) a1.attach_grad() a2.attach_grad() expected_np = _np.copysign(a1_np, a2_np) From 8ff8a13666a6fd34c411b562336ae2587ac09945 Mon Sep 17 00:00:00 2001 From: Fan Date: Sun, 15 Sep 2019 16:18:27 +0800 Subject: [PATCH 3/3] add test --- tests/python/unittest/test_numpy_op.py | 47 +++++++++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 6fa2199bb4f0..1f2af8dfa4a9 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1870,7 +1870,17 @@ def get_grad(a1, a2): sign = sign.reshape(-1, *a1.shape) sign = _np.sum(sign, axis=0) return sign, _np.zeros_like(a2) - + + def get_grad_left(a1, a2): + sign = 
From 8ff8a13666a6fd34c411b562336ae2587ac09945 Mon Sep 17 00:00:00 2001
From: Fan
Date: Sun, 15 Sep 2019 16:18:27 +0800
Subject: [PATCH 3/3] add test

---
 tests/python/unittest/test_numpy_op.py | 47 +++++++++++++++++++++++++-
 1 file changed, 46 insertions(+), 1 deletion(-)

diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 6fa2199bb4f0..1f2af8dfa4a9 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1870,7 +1870,17 @@ def get_grad(a1, a2):
         sign = sign.reshape(-1, *a1.shape)
         sign = _np.sum(sign, axis=0)
         return sign, _np.zeros_like(a2)
-
+
+    def get_grad_left(a1, a2):
+        sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
+                              _np.logical_and(a1 >= 0, a2 >= 0))
+        sign = 2 * sign.astype(int) - 1
+        sign = sign.reshape(a1.shape)
+        return sign
+
+    def get_grad_right(a1, a2):
+        return _np.zeros_like(a2)
+
     shapes = [
         (),
         (1),
@@ -1911,6 +1921,41 @@ def get_grad(a1, a2):
                     mx_out = np.copysign(a1, a2)
                     expected_np = _np.copysign(a1_np, a2_np)
                     assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+    types = ['float16', 'float32', 'float64']
+    for x_shape in shapes:
+        for dtype in types:
+            # Test left
+            x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
+            scalar = _np.random.uniform(-2.0, 2.0)
+            x = np.array(x_np, dtype=dtype)
+            x.attach_grad()
+            expected_np = _np.copysign(x_np, scalar)
+            with mx.autograd.record():
+                mx_out = np.copysign(x, scalar)
+            assert mx_out.shape == expected_np.shape
+            assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+            # Test gradient
+            mx_out.backward()
+            x_grad = get_grad_left(x_np, scalar)
+            assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
+
+            # Test right
+            x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
+            scalar = _np.random.uniform(-2.0, 2.0)
+            x = np.array(x_np, dtype=dtype)
+            x.attach_grad()
+            expected_np = _np.copysign(scalar, x_np)
+            with mx.autograd.record():
+                mx_out = np.copysign(scalar, x)
+            assert mx_out.shape == expected_np.shape
+            assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)
+
+            # Test gradient
+            mx_out.backward()
+            x_grad = get_grad_right(scalar, x_np)
+            assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)
 
 
 if __name__ == '__main__':