From 8af1b5756d965128e19b3c6535a9e3fc2f457013 Mon Sep 17 00:00:00 2001 From: tingying Date: Mon, 23 Sep 2019 13:38:37 +0800 Subject: [PATCH] numpy operator arctan2 (#15890) * change the test code * add @use_np in test code * only support float16, float32 and float64. * fix format error * remove redundant backslash * change wrapper in symbol * delete gpu test * edit test * change infer type * remove redundant **kwargs * change atol and rtol in test * edit test shape --- python/mxnet/ndarray/numpy/_op.py | 92 ++++++++++++++++++- python/mxnet/numpy/multiarray.py | 91 +++++++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 72 ++++++++++++++- src/operator/math_functions-inl.h | 2 + src/operator/mshadow_op.h | 10 ++ .../numpy/np_elemwise_broadcast_op.cc | 70 ++++++++++++++ .../numpy/np_elemwise_broadcast_op.cu | 19 ++++ src/operator/operator_tune.cc | 5 + tests/python/unittest/test_numpy_op.py | 67 ++++++++++++++ 9 files changed, 420 insertions(+), 8 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 2cdfff173b8e..197bae614745 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -28,9 +28,9 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', - 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', - 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', - 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', + 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', + 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', + 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 
'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', @@ -2953,3 +2953,89 @@ def around(x, decimals=0, out=None, **kwargs): return _npi.around(x, decimals, out=out, **kwargs) else: raise TypeError('type {} not supported'.format(str(type(x)))) + + +@set_module('mxnet.ndarray.numpy') +def arctan2(x1, x2, out=None): + r""" + arctan2(x1, x2, out=None) + + Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly. + + The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is + the signed angle in radians between the ray ending at the origin and + passing through the point (1,0), and the ray ending at the origin and + passing through the point (`x2`, `x1`). (Note the role reversal: the + "`y`-coordinate" is the first function parameter, the "`x`-coordinate" + is the second.) By IEEE convention, this function is defined for + `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see + Notes for specific values). + + This function is not defined for complex-valued arguments; for the + so-called argument of complex values, use `angle`. + + Parameters + ---------- + x1 : ndarray or scalar + `y`-coordinates. + x2 : ndarray or scalar + `x`-coordinates. `x2` must be broadcastable to match the shape of + `x1` or vice versa. + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : ndarray or scalar + Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if + `x1` and `x2` are scalars. + + Notes + ----- + *arctan2* is identical to the `atan2` function of the underlying + C library. 
The following special values are defined in the C + standard: [1]_ + + ====== ====== ================ + `x1` `x2` `arctan2(x1,x2)` + ====== ====== ================ + +/- 0 +0 +/- 0 + +/- 0 -0 +/- pi + > 0 +/-inf +0 / +pi + < 0 +/-inf -0 / -pi + +/-inf +inf +/- (pi/4) + +/-inf -inf +/- (3*pi/4) + ====== ====== ================ + + Note that +0 and -0 are distinct floating point numbers, as are +inf + and -inf. + + This function differs from the original numpy.arctan2 in the following aspects: + - Only supports float16, float32 and float64. + + References + ---------- + .. [1] ISO/IEC standard 9899:1999, "Programming language C." + + Examples + -------- + Consider four points in different quadrants: + + >>> x = np.array([-1, +1, +1, -1]) + >>> y = np.array([-1, -1, +1, +1]) + >>> np.arctan2(y, x) * 180 / np.pi + array([-135., -45., 45., 135.]) + + Note the order of the parameters. `arctan2` is defined also when `x2` = 0 + and at several other special points, obtaining values in + the range ``[-pi, pi]``: + + >>> x = np.array([1, -1]) + >>> y = np.array([0, 0]) + >>> np.arctan2(x, y) + array([ 1.5707964, -1.5707964]) + """ + return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2, + _npi.arctan2_scalar, _npi.rarctan2_scalar, out=out) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e0c7a67d4715..0cd90365be02 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -47,13 +47,13 @@ from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'full', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'remainder', 'power', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', - 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', + 'mod', 'remainder', 'power', 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', + 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', - 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around'] + 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -4481,3 +4481,88 @@ def around(x, decimals=0, out=None, **kwargs): array([ 0, 0, 0, 10]) """ return _mx_nd_np.around(x, decimals, out=out, **kwargs) + + +@set_module('mxnet.numpy') +def arctan2(x1, x2, out=None): + r""" + arctan2(x1, x2, out=None) + + Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly. + + The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is + the signed angle in radians between the ray ending at the origin and + passing through the point (1,0), and the ray ending at the origin and + passing through the point (`x2`, `x1`). (Note the role reversal: the + "`y`-coordinate" is the first function parameter, the "`x`-coordinate" + is the second.) By IEEE convention, this function is defined for + `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see + Notes for specific values). + + This function is not defined for complex-valued arguments; for the + so-called argument of complex values, use `angle`. + + Parameters + ---------- + x1 : ndarray or scalar + `y`-coordinates. + x2 : ndarray or scalar + `x`-coordinates. `x2` must be broadcastable to match the shape of + `x1` or vice versa. + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. If not provided or `None`, + a freshly-allocated array is returned. 
+ + Returns + ------- + out : ndarray or scalar + Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if + `x1` and `x2` are scalars. + + Notes + ----- + *arctan2* is identical to the `atan2` function of the underlying + C library. The following special values are defined in the C + standard: [1]_ + + ====== ====== ================ + `x1` `x2` `arctan2(x1,x2)` + ====== ====== ================ + +/- 0 +0 +/- 0 + +/- 0 -0 +/- pi + > 0 +/-inf +0 / +pi + < 0 +/-inf -0 / -pi + +/-inf +inf +/- (pi/4) + +/-inf -inf +/- (3*pi/4) + ====== ====== ================ + + Note that +0 and -0 are distinct floating point numbers, as are +inf + and -inf. + + This function differs from the original numpy.arctan2 in the following aspects: + - Only supports float16, float32 and float64. + + References + ---------- + .. [1] ISO/IEC standard 9899:1999, "Programming language C." + + Examples + -------- + Consider four points in different quadrants: + + >>> x = np.array([-1, +1, +1, -1]) + >>> y = np.array([-1, -1, +1, +1]) + >>> np.arctan2(y, x) * 180 / np.pi + array([-135., -45., 45., 135.]) + + Note the order of the parameters. `arctan2` is defined also when `x2` = 0 + and at several other special points, obtaining values in + the range ``[-pi, pi]``: + + >>> x = np.array([1, -1]) + >>> y = np.array([0, 0]) + >>> np.arctan2(x, y) + array([ 1.5707964, -1.5707964]) + """ + return _mx_nd_np.arctan2(x1, x2, out=out) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 1bfbba209989..94a4a37d273e 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -29,8 +29,8 @@ from .._internal import _set_np_symbol_class from . 
import _internal as _npi -__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'sin', - 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', +__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'remainder', 'power', 'arctan2', + 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', @@ -3172,4 +3172,72 @@ def around(x, decimals=0, out=None, **kwargs): raise TypeError('type {} not supported'.format(str(type(x)))) +@set_module('mxnet.symbol.numpy') +def arctan2(x1, x2, out=None): + r""" + arctan2(x1, x2, out=None) + + Element-wise arc tangent of ``x1/x2`` choosing the quadrant correctly. + + The quadrant (i.e., branch) is chosen so that ``arctan2(x1, x2)`` is + the signed angle in radians between the ray ending at the origin and + passing through the point (1,0), and the ray ending at the origin and + passing through the point (`x2`, `x1`). (Note the role reversal: the + "`y`-coordinate" is the first function parameter, the "`x`-coordinate" + is the second.) By IEEE convention, this function is defined for + `x2` = +/-0 and for either or both of `x1` and `x2` = +/-inf (see + Notes for specific values). + + This function is not defined for complex-valued arguments; for the + so-called argument of complex values, use `angle`. + + Parameters + ---------- + x1 : _Symbol or scalar + `y`-coordinates. + x2 : _Symbol or scalar + `x`-coordinates. `x2` must be broadcastable to match the shape of + `x1` or vice versa. + out : _Symbol or None, optional + A location into which the result is stored. If provided, it must have + a shape that the inputs broadcast to. 
If not provided or `None`, + a freshly-allocated array is returned. + + Returns + ------- + out : _Symbol or scalar + Array of angles in radians, in the range ``[-pi, pi]``. This is a scalar if + `x1` and `x2` are scalars. + + Notes + ----- + *arctan2* is identical to the `atan2` function of the underlying + C library. The following special values are defined in the C + standard: [1]_ + + ====== ====== ================ + `x1` `x2` `arctan2(x1,x2)` + ====== ====== ================ + +/- 0 +0 +/- 0 + +/- 0 -0 +/- pi + > 0 +/-inf +0 / +pi + < 0 +/-inf -0 / -pi + +/-inf +inf +/- (pi/4) + +/-inf -inf +/- (3*pi/4) + ====== ====== ================ + + Note that +0 and -0 are distinct floating point numbers, as are +inf + and -inf. + + This function differs from the original numpy.arctan2 in the following aspects: + - Only supports float16, float32 and float64. + + References + ---------- + .. [1] ISO/IEC standard 9899:1999, "Programming language C." + """ + return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2, + _npi.arctan2_scalar, _npi.rarctan2_scalar, out=out) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/math_functions-inl.h b/src/operator/math_functions-inl.h index 45d74a62d8bd..5f95654cad37 100644 --- a/src/operator/math_functions-inl.h +++ b/src/operator/math_functions-inl.h @@ -125,6 +125,8 @@ MXNET_BINARY_MATH_FUNC(hypot) MXNET_BINARY_MATH_FUNC(pow) +MXNET_BINARY_MATH_FUNC(atan2) + template MSHADOW_XINLINE float id(DType a) { return static_cast(a); diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index f3d24b2c119e..6261638c03ec 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -322,6 +322,16 @@ MXNET_BINARY_MATH_OP(rpower, math::pow(b, a)); MXNET_BINARY_MATH_OP(rpower_grad, math::id(a) * math::log(b)); +MXNET_BINARY_MATH_OP(arctan2, math::atan2(a, b)); + +MXNET_BINARY_MATH_OP(arctan2_grad, math::id(b) / (math::id(a * a + b * b))); + +MXNET_BINARY_MATH_OP(arctan2_rgrad, -math::id(a) / (math::id(a * a + b * 
b))); + +MXNET_BINARY_MATH_OP(rarctan2, math::atan2(b, a)); + +MXNET_BINARY_MATH_OP(rarctan2_grad, math::id(a) / (math::id(a * a + b * b))); + MXNET_UNARY_MATH_OP_NC(nt, a != DType(0) ? DType(0) : DType(1)); MXNET_BINARY_MATH_OP_NC(ge, a >= b ? DType(1) : DType(0)); diff --git a/src/operator/numpy/np_elemwise_broadcast_op.cc b/src/operator/numpy/np_elemwise_broadcast_op.cc index a9254e8a7d02..f9293ee35a60 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op.cc @@ -144,5 +144,75 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); +inline bool IsFloatType(const int dtype) { + return (dtype == mshadow::kFloat16 || + dtype == mshadow::kFloat32 || + dtype == mshadow::kFloat64); +} + +inline bool Arctan2OpType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 2U); + CHECK_EQ(out_attrs->size(), 1U); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + // check if it is float16, float32 or float64. If not, raise error. 
+ CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n"; + return out_attrs->at(0) != -1; +} + +NNVM_REGISTER_OP(_npi_arctan2) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x1", "x2"}; + }) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", Arctan2OpType) +.set_attr("FCompute", BinaryBroadcastCompute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2"}) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.add_argument("x1", "NDArray-or-Symbol", "The input array") +.add_argument("x2", "NDArray-or-Symbol", "The input array"); + +NNVM_REGISTER_OP(_backward_npi_arctan2) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_arctan2_scalar"}); + +MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_npi_rarctan2_scalar"}); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_arctan2_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", + BinaryScalarOp::Backward); + +MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar) +.add_argument("scalar", "float", "scalar value") +.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); }) +.set_attr("FCompute", + BinaryScalarOp::Backward); + } // namespace op } // namespace mxnet diff --git 
a/src/operator/numpy/np_elemwise_broadcast_op.cu b/src/operator/numpy/np_elemwise_broadcast_op.cu index ecf8e8531334..ab76e5c6fd7d 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op.cu @@ -49,6 +49,13 @@ NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastCompute); + +NNVM_REGISTER_OP(_backward_npi_arctan2) +.set_attr("FCompute", BinaryBroadcastBackwardUseIn); + NNVM_REGISTER_OP(_npi_add_scalar) .set_attr("FCompute", BinaryScalarOp::Compute); @@ -87,5 +94,17 @@ NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar) .set_attr("FCompute", BinaryScalarOp::Backward); +NNVM_REGISTER_OP(_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_arctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + +NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) +.set_attr("FCompute", BinaryScalarOp::Compute); + } // namespace op } // namespace mxnet diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 51595254cafc..1d644386cdbb 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -333,6 +333,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::arctan2); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rarctan2); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan2_grad); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rarctan2_grad); // NOLINT() 
+IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::arctan2_rgrad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1324af03d1b5..3d30012750e9 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2380,6 +2380,73 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_arctan2(): + class TestArctan2(HybridBlock): + def __init__(self): + super(TestArctan2, self).__init__() + + def hybrid_forward(self, F, x1, x2): + return F.np.arctan2(x1, x2) + + # Reduce dimension of src to dimension of des. + def dimReduce(src, des): + srcShape = src.shape + desShape = des.shape + if len(desShape) == 0: + return src.sum() + redu = [] + for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, -1, -1)): + if srcShape[i] != desShape[j] and desShape[j] == 1: + redu.append(i) + if j == 0: + for k in range(0, i): + redu.append(k) + break + if len(redu) > 0: + src = _np.reshape(src.sum(axis=tuple(redu)), desShape) + return src + + types = ['float64', 'float32', 'float16'] + for hybridize in [True, False]: + for shape1, shape2 in [[(3, 2), (3, 2)], # tall matrices + [(), ()], # scalar only + [(3, 0, 2), (3, 0, 2)], # zero-dim + [(3, 4, 5), (4, 1)], # trailing dim broadcasting + [(3, 4, 5), ()], # scalar broadcasting + [(), (1, 2, 3)], # scalar broadcasting + ]: + for oneType in types: + rtol = 1e-2 if oneType == 'float16' else 1e-3 + atol = 1e-2 if oneType == 'float16' else 1e-5 + test_arctan2 = TestArctan2() + if hybridize: + test_arctan2.hybridize() + x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray() + x2 = rand_ndarray(shape2, 
dtype=oneType).as_np_ndarray() + x11 = x1.asnumpy() + x21 = x2.asnumpy() + x1.attach_grad() + x2.attach_grad() + np_out = _np.arctan2(x1.asnumpy(), x2.asnumpy()) + with mx.autograd.record(): + mx_out = test_arctan2(x1, x2) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + mx_out.backward() + np_backward_1 = x21 / (x11 * x11 + x21 * x21) + np_backward_2 = -1 * x11 / (x11 * x11 + x21 * x21) + np_backward_1 = dimReduce(np_backward_1, x11) + np_backward_2 = dimReduce(np_backward_2, x21) + assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol) + assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol) + + mx_out = np.arctan2(x1, x2) + np_out = _np.arctan2(x1.asnumpy(), x2.asnumpy()) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()