Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy operator hypot (#15901)
Browse files Browse the repository at this point in the history
* rebase master

* edit test

* add IsIntType to check the input type

* fix error in test

* remove hypot in doc
  • Loading branch information
tingying2020 authored and reminisce committed Sep 23, 2019
1 parent c36819e commit b62d1c2
Show file tree
Hide file tree
Showing 7 changed files with 254 additions and 3 deletions.
49 changes: 48 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -3039,3 +3039,50 @@ def arctan2(x1, x2, out=None):
"""
return _ufunc_helper(x1, x2, _npi.arctan2, _np.arctan2,
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)


@set_module('mxnet.ndarray.numpy')
def hypot(x1, x2, out=None):
r"""
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
Examples
--------
>>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3)))
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
Example showing broadcast of scalar_like argument:
>>> np.hypot(3*np.ones((3, 3)), [4])
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
51 changes: 50 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2']
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -4566,3 +4566,52 @@ def arctan2(x1, x2, out=None):
array([ 1.5707964, -1.5707964])
"""
return _mx_nd_np.arctan2(x1, x2, out=out)


@set_module('mxnet.numpy')
def hypot(x1, x2, out=None):
r"""
hypot(x1, x2, out=None)
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
Examples
--------
>>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3)))
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
Example showing broadcast of scalar_like argument:
>>> np.hypot(3*np.ones((3, 3)), [4])
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _mx_nd_np.hypot(x1, x2, out=out)
35 changes: 34 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot']


def _num_outputs(sym):
Expand Down Expand Up @@ -3240,4 +3240,37 @@ def arctan2(x1, x2, out=None):
_npi.arctan2_scalar, _npi.rarctan2_scalar, out=out)


@set_module('mxnet.symbol.numpy')
def hypot(x1, x2, out=None):
r"""
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


_set_np_symbol_class(_Symbol)
49 changes: 49 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,5 +214,54 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rarctan2_scalar)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::arctan2_rgrad>);

bool HypotOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));

CHECK(IsFloatType(in_attrs->at(0))) << "Do not support `int` as input.\n";
return out_attrs->at(0) != -1;
}

// rigister hypot that do not support int here
NNVM_REGISTER_OP(_npi_hypot)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x1", "x2"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", HypotOpType)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::hypot>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
.add_argument("x1", "NDArray-or-Symbol", "The input array")
.add_argument("x2", "NDArray-or-Symbol", "The input array");

NNVM_REGISTER_OP(_backward_npi_hypot)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> > {{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::hypot_grad_left,
mshadow_op::hypot_grad_right>);

} // namespace op
} // namespace mxnet
6 changes: 6 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,12 @@ NNVM_REGISTER_OP(_npi_arctan2)
NNVM_REGISTER_OP(_backward_npi_arctan2)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::arctan2_grad,
mshadow_op::arctan2_rgrad>);
NNVM_REGISTER_OP(_npi_hypot)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::hypot>);

NNVM_REGISTER_OP(_backward_npi_hypot)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::hypot_grad_left,
mshadow_op::hypot_grad_right>);

NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/elemwise_binary_scalar_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar)
cpu, mshadow_op::rpower_grad>);

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar)
.add_alias("_npi_hypot_scalar")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<
cpu, mshadow_op::hypot>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_hypot_scalar" })
Expand Down
66 changes: 66 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2479,6 +2479,72 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol)


@with_seed()
@use_np
def test_np_hypot():
class TestHypot(HybridBlock):
def __init__(self):
super(TestHypot, self).__init__()

def hybrid_forward(self, F, x1, x2):
return F.np.hypot(x1, x2)

def dimReduce(src, des):
srcShape = src.shape
desShape = des.shape
if len(desShape) == 0:
return src.sum()
redu = []
for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, -1, -1)):
if srcShape[i] != desShape[j] and desShape[j] == 1:
redu.append(i)
if j == 0:
for k in range(0, i):
redu.append(k)
break
if len(redu) > 0:
src = _np.reshape(src.sum(axis=tuple(redu)), desShape)
return src

types = ['float64', 'float32', 'float16']
for hybridize in [True, False]:
for shape1, shape2 in [[(3, 2), (3, 2)], # tall matrices
[(), ()], # scalar only
[(3, 0, 2), (3, 0, 2)], # zero-dim
[(3, 4, 5), (4, 1)], # trailing dim broadcasting
[(3, 4, 5), ()], # scalar broadcasting
[(), (1, 2, 3)], # scalar broadcasting
]:
for oneType in types:
rtol = 1e-2 if oneType == 'float16' else 1e-3
atol = 1e-2 if oneType == 'float16' else 1e-5
test_hypot = TestHypot()
if hybridize:
test_hypot.hybridize()
x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray()
x2 = rand_ndarray(shape2, dtype=oneType).as_np_ndarray()
x11 = x1.asnumpy()
x21 = x2.asnumpy()
x1.attach_grad()
x2.attach_grad()
np_out = _np.hypot(x1.asnumpy(), x2.asnumpy())
with mx.autograd.record():
mx_out = test_hypot(x1, x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
mx_out.backward()
np_backward_1 = x11 / np_out
np_backward_2 = x21 / np_out
np_backward_1 = dimReduce(np_backward_1, x11)
np_backward_2 = dimReduce(np_backward_2, x21)
assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol)
assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol)

mx_out = np.hypot(x1, x2)
np_out = _np.hypot(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit b62d1c2

Please sign in to comment.