From 1783ed1d7c7bce9d572f9cd053067d0cbcfddcfa Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 15 Aug 2019 13:12:52 +0800 Subject: [PATCH] numpy operator hypot * rebase master --- python/mxnet/ndarray/numpy/_op.py | 51 ++++++++++++- python/mxnet/numpy/multiarray.py | 51 ++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 37 ++++++++- .../numpy/np_elemwise_broadcast_op.cc | 54 +++++++++++++ .../numpy/np_elemwise_broadcast_op.cu | 7 ++ .../elemwise_binary_scalar_op_extended.cc | 1 + tests/python/gpu/test_operator_gpu.py | 76 +++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 74 ++++++++++++++++++ 8 files changed, 348 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d7f3fd1ace54..f2f96e3133e6 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,7 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'hypot'] @set_module('mxnet.ndarray.numpy') @@ -705,3 +705,52 @@ def concatenate(seq, axis=0, out=None): The concatenated array. """ return _npi.concatenate(*seq, dim=axis, out=out) + + +@set_module('mxnet.ndarray.numpy') +def hypot(x1, x2, out=None): + r""" + hypot(x1, x2, out=None) + + Given the "legs" of a right triangle, return its hypotenuse. + + Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or + `x2` is scalar_like (i.e., unambiguously cast-able to a scalar type), + it is broadcast for use with each element of the other argument. + + Parameters + ---------- + x1, x2 : array_like + Leg of the triangle(s). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. 
If not provided or `None`, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + + Returns + ------- + z : ndarray + The hypotenuse of the triangle(s). + This is a scalar if both `x1` and `x2` are scalars. + + Notes + ----- + This function differs from the original numpy.hypot in the following aspects: + - Only support float16, float32 and float64. + + Examples + -------- + >>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3))) + array([[ 5., 5., 5.], + [ 5., 5., 5.], + [ 5., 5., 5.]]) + + Example showing broadcast of scalar_like argument: + + >>> np.hypot(3*np.ones((3, 3)), [4]) + array([[ 5., 5., 5.], + [ 5., 5., 5.], + [ 5., 5., 5.]]) + """ + return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 8988b4eb19c9..4bd9190303be 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -45,7 +45,7 @@ __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', - 'concatenate'] + 'concatenate', 'hypot'] # This function is copied from ndarray.py since pylint @@ -1877,3 +1877,52 @@ def concatenate(seq, axis=0, out=None): The concatenated array. """ return _mx_nd_np.concatenate(seq, axis=axis, out=out) + + +@set_module('mxnet.numpy') +def hypot(x1, x2, out=None): + r""" + hypot(x1, x2, out=None) + + Given the "legs" of a right triangle, return its hypotenuse. + + Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or + `x2` is scalar_like (i.e., unambiguously cast-able to a scalar type), + it is broadcast for use with each element of the other argument. + + Parameters + ---------- + x1, x2 : array_like + Leg of the triangle(s). 
+ out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + + Returns + ------- + z : ndarray + The hypotenuse of the triangle(s). + This is a scalar if both `x1` and `x2` are scalars. + + Notes + ----- + This function differs from the original numpy.hypot in the following aspects: + - Only support float16, float32 and float64. + + Examples + -------- + >>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3))) + array([[ 5., 5., 5.], + [ 5., 5., 5.], + [ 5., 5., 5.]]) + + Example showing broadcast of scalar_like argument: + + >>> np.hypot(3*np.ones((3, 3)), [4]) + array([[ 5., 5., 5.], + [ 5., 5., 5.], + [ 5., 5., 5.]]) + """ + return _mx_nd_np.hypot(x1, x2, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index a6699d60871a..c518753c0573 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'hypot'] def _num_outputs(sym): @@ -1335,4 +1335,39 @@ def concatenate(seq, axis=0, out=None): return _npi.concatenate(*seq, dim=axis, out=out) +@set_module('mxnet.symbol.numpy') +def hypot(x1, x2, out=None): + r""" + hypot(x1, x2, out=None) + + Given the "legs" of a right triangle, return its hypotenuse. + + Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or + `x2` is scalar_like (i.e., unambiguously cast-able to a scalar type), + it is broadcast for use with each element of the other argument. 
+ + Parameters + ---------- + x1, x2 : array_like + Leg of the triangle(s). + out : ndarray, None, or tuple of ndarray and None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. A tuple (possible only as a + keyword argument) must have length equal to the number of outputs. + + Returns + ------- + z : ndarray + The hypotenuse of the triangle(s). + This is a scalar if both `x1` and `x2` are scalars. + + Notes + ----- + This function differs from the original numpy.hypot in the following aspects: + - Only support float16, float32 and float64. + """ + return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index c36423dff9fd..adc9650496f0 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -182,5 +182,59 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar) .set_attr("FCompute", BinaryScalarOp::Compute) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"}); +inline bool HypotOpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + // check if it is float16, float32 or float64. If not, raise error. + if (in_attrs->at(0) > mshadow::kFloat16) { + // do not support int now. 
+ std::ostringstream os; + os << "Do not support `int` as input.\n"; + throw ::mxnet::op::InferTypeError(os.str(), 0); + } + return out_attrs->at(0) != -1; +} + +// register hypot, which does not support int inputs, here +NNVM_REGISTER_OP(_npi_hypot) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", HypotOpType) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}, {1, 0}}; + }) +.add_argument("x1", "NDArray-or-Symbol", "The input array") +.add_argument("x2", "NDArray-or-Symbol", "The input array"); + +NNVM_REGISTER_OP(_backward_npi_hypot) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs) { + return std::vector > {{0, 1}}; +}) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index c858b3a4987a..6e24541b090f 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -48,6 +48,13 @@ NNVM_REGISTER_OP(_npi_maximum) NNVM_REGISTER_OP(_npi_minimum) .set_attr("FCompute", BinaryBroadcastCompute); +NNVM_REGISTER_OP(_npi_hypot) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_hypot) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); diff --git a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc 
b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc index 3a687c2aa062..3339e5edcfdd 100644 --- a/src/operator/tensor/elemwise_binary_scalar_op_extended.cc +++ b/src/operator/tensor/elemwise_binary_scalar_op_extended.cc @@ -72,6 +72,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar) cpu, mshadow_op::rpower_grad>); MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar) +.add_alias("_npi_hypot_scalar") .set_attr("FCompute", BinaryScalarOp::Compute< cpu, mshadow_op::hypot>) .set_attr("FGradient", ElemwiseGradUseIn{ "_backward_hypot_scalar" }) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index f8d8b4496afc..4dc0423ed312 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2331,6 +2331,82 @@ def test_math(): for op in ops: run_math(op, shape, dtype, check_value=check_value) + +@with_seed() +@use_np +def test_np_hypot(): + class TestHypot(HybridBlock): + def __init__(self): + super(TestHypot, self).__init__() + + def hybrid_forward(self, F, x1, x2): + return F.np.hypot(x1, x2) + + # Reduce the dimensions of src to the dimensions of des. 
+ def dimReduce(src, des): + srcShape = src.shape + desShape = des.shape + if len(desShape) == 0: + return src.sum() + redu = [] + for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, -1, -1)): + if srcShape[i] != desShape[j] and desShape[j] == 1: + redu.append(i) + if j == 0: + for k in range(0, i): + redu.append(k) + break + if len(redu) > 0: + src = np.reshape(src.sum(axis=tuple(redu)), desShape) + return src + + types = ['float64', 'float32', 'float16'] + for hybridize in [True, False]: + for shape1, shape2 in [[(1,), (1,)], # single elements + [(4, 5), (4, 5)], # normal case + [(3, 2), (3, 2)], # tall matrices + [(), ()], # scalar only + [(3, 0, 2), (3, 0, 2)], # zero-dim + [(3, 4, 5), (4, 1)], # trailing dim broadcasting + [(3, 4, 5), ()], # scalar broadcasting + [(), (1, 2, 3)], # scalar broadcasting + [(4, 3), (4, 1)], # single broadcasting + [(3, 4, 5), (3, 1, 5)] # single broadcasting in the middle + ]: + for oneType in types: + if oneType == 'float16': + rtol = 1e-2 + atol = 1e-2 + else: + rtol=1e-3 + atol = 1e-5 + test_hypot = TestHypot() + if hybridize: + test_hypot.hybridize() + x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray() + x2 = rand_ndarray(shape2, dtype=oneType).as_np_ndarray() + x11 = x1.asnumpy() + x21 = x2.asnumpy() + x1.attach_grad() + x2.attach_grad() + np_out = np.hypot(x1.asnumpy(), x2.asnumpy()) + with mx.autograd.record(): + mx_out = test_hypot(x1, x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + np_backward_1 = x11 / np_out + np_backward_2 = x21 / np_out + np_backward_1 = dimReduce(np_backward_1, x11) + np_backward_2 = dimReduce(np_backward_2, x21) + assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol) + + mx_out = mx.np.hypot(x1, x2) + np_out = np.hypot(x1.asnumpy(), x2.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), 
np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2291bcdb6d3d..905c591ea695 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -880,6 +880,80 @@ def get_new_shape(shape, axis): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_hypot(): + class TestHypot(HybridBlock): + def __init__(self): + super(TestHypot, self).__init__() + + def hybrid_forward(self, F, x1, x2): + return F.np.hypot(x1, x2) + + def dimReduce(src, des): + srcShape = src.shape + desShape = des.shape + if len(desShape) == 0: + return src.sum() + redu = [] + for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, -1, -1)): + if srcShape[i] != desShape[j] and desShape[j] == 1: + redu.append(i) + if j == 0: + for k in range(0, i): + redu.append(k) + break + if len(redu) > 0: + src = _np.reshape(src.sum(axis=tuple(redu)), desShape) + return src + + types = ['float64', 'float32', 'float16'] + for hybridize in [True, False]: + for shape1, shape2 in [[(1,), (1,)], # single elements + [(4, 5), (4, 5)], # normal case + [(3, 2), (3, 2)], # tall matrices + [(), ()], # scalar only + [(3, 0, 2), (3, 0, 2)], # zero-dim + [(3, 4, 5), (4, 1)], # trailing dim broadcasting + [(3, 4, 5), ()], # scalar broadcasting + [(), (1, 2, 3)], # scalar broadcasting + [(4, 3), (4, 1)], # single broadcasting + [(3, 4, 5), (3, 1, 5)] # single broadcasting in the middle + ]: + for oneType in types: + if oneType == 'float16': + rtol = 1e-2 + atol = 1e-2 + else: + rtol = 1e-3 + atol = 1e-5 + test_hypot = TestHypot() + if hybridize: + test_hypot.hybridize() + x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray() + x2 = rand_ndarray(shape2, dtype=oneType).as_np_ndarray() + x11 = x1.asnumpy() + x21 = x2.asnumpy() + x1.attach_grad() + x2.attach_grad() + np_out = 
_np.hypot(x1.asnumpy(), x2.asnumpy()) + with mx.autograd.record(): + mx_out = test_hypot(x1, x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + np_backward_1 = x11 / np_out + np_backward_2 = x21 / np_out + np_backward_1 = dimReduce(np_backward_1, x11) + np_backward_2 = dimReduce(np_backward_2, x21) + assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol) + + mx_out = np.hypot(x1, x2) + np_out = _np.hypot(x1.asnumpy(), x2.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()