Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy operator hypot
Browse files Browse the repository at this point in the history
* rebase master
  • Loading branch information
Ying committed Aug 15, 2019
1 parent 40593c6 commit 1783ed1
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 3 deletions.
51 changes: 50 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'hypot']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -705,3 +705,52 @@ def concatenate(seq, axis=0, out=None):
The concatenated array.
"""
return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.ndarray.numpy')
def hypot(x1, x2, out=None):
r"""
hypot(x1, x2, out=None)
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
Examples
--------
>>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3)))
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
Example showing broadcast of scalar_like argument:
>>> np.hypot(3*np.ones((3, 3)), [4])
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)
51 changes: 50 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split',
'concatenate']
'concatenate', 'hypot']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1877,3 +1877,52 @@ def concatenate(seq, axis=0, out=None):
The concatenated array.
"""
return _mx_nd_np.concatenate(seq, axis=axis, out=out)


@set_module('mxnet.numpy')
def hypot(x1, x2, out=None):
r"""
hypot(x1, x2, out=None)
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
Examples
--------
>>> np.hypot(3*np.ones((3, 3)), 4*np.ones((3, 3)))
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
Example showing broadcast of scalar_like argument:
>>> np.hypot(3*np.ones((3, 3)), [4])
array([[ 5., 5., 5.],
[ 5., 5., 5.],
[ 5., 5., 5.]])
"""
return _mx_nd_np.hypot(x1, x2, out=out)
37 changes: 36 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'hypot']


def _num_outputs(sym):
Expand Down Expand Up @@ -1335,4 +1335,39 @@ def concatenate(seq, axis=0, out=None):
return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.symbol.numpy')
def hypot(x1, x2, out=None):
r"""
hypot(x1, x2, out=None)
Given the "legs" of a right triangle, return its hypotenuse.
Equivalent to ``sqrt(x1**2 + x2**2)``, element-wise. If `x1` or
`x2` is scalar_like (i.e., unambiguously cast-able to a scalar type),
it is broadcast for use with each element of the other argument.
Parameters
----------
x1, x2 : array_like
Leg of the triangle(s).
out : ndarray, None, or tuple of ndarray and None, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not provided or `None`,
a freshly-allocated array is returned. A tuple (possible only as a
keyword argument) must have length equal to the number of outputs.
Returns
-------
z : ndarray
The hypotenuse of the triangle(s).
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


_set_np_symbol_class(_Symbol)
54 changes: 54 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,5 +182,59 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});

inline bool HypotOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
// check if it is float16, float32 or float64. If not, raise error.
if (in_attrs->at(0) > mshadow::kFloat16) {
// do not support int now.
std::ostringstream os;
os << "Do not support `int` as input.\n";
throw ::mxnet::op::InferTypeError(os.str(), 0);
}
return out_attrs->at(0) != -1;
}

// rigister hypot that do not support int here
NNVM_REGISTER_OP(_npi_hypot)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x1", "x2"};
})
.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
.set_attr<nnvm::FInferType>("FInferType", HypotOpType)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::hypot>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_hypot"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
})
.add_argument("x1", "NDArray-or-Symbol", "The input array")
.add_argument("x2", "NDArray-or-Symbol", "The input array");

NNVM_REGISTER_OP(_backward_npi_hypot)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> > {{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::hypot_grad_left,
mshadow_op::hypot_grad_right>);

} // namespace op
} // namespace mxnet
7 changes: 7 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,13 @@ NNVM_REGISTER_OP(_npi_maximum)
NNVM_REGISTER_OP(_npi_minimum)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::minimum>);

NNVM_REGISTER_OP(_npi_hypot)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::hypot>);

NNVM_REGISTER_OP(_backward_npi_hypot)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::hypot_grad_left,
mshadow_op::hypot_grad_right>);

NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);

Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/elemwise_binary_scalar_op_extended.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rpower_scalar)
cpu, mshadow_op::rpower_grad>);

MXNET_OPERATOR_REGISTER_BINARY_SCALAR(_hypot_scalar)
.add_alias("_npi_hypot_scalar")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<
cpu, mshadow_op::hypot>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_hypot_scalar" })
Expand Down
76 changes: 76 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2331,6 +2331,82 @@ def test_math():
for op in ops:
run_math(op, shape, dtype, check_value=check_value)


@with_seed()
@use_np
def test_np_hypot():
class TestHypot(HybridBlock):
def __init__(self):
super(TestHypot, self).__init__()

def hybrid_forward(self, F, x1, x2):
return F.np.hypot(x1, x2)

#Reduce dimension of src to dimention of des.
def dimReduce(src, des):
srcShape = src.shape
desShape = des.shape
if len(desShape) == 0:
return src.sum()
redu = []
for i, j in zip(range(len(srcShape)-1, -1, -1), range(len(desShape)-1, -1, -1)):
if srcShape[i] != desShape[j] and desShape[j] == 1:
redu.append(i)
if j == 0:
for k in range(0, i):
redu.append(k)
break
if len(redu) > 0:
src = np.reshape(src.sum(axis=tuple(redu)), desShape)
return src

types = ['float64', 'float32', 'float16']
for hybridize in [True, False]:
for shape1, shape2 in [[(1,), (1,)], # single elements
[(4, 5), (4, 5)], # normal case
[(3, 2), (3, 2)], # tall matrices
[(), ()], # scalar only
[(3, 0, 2), (3, 0, 2)], # zero-dim
[(3, 4, 5), (4, 1)], # trailing dim broadcasting
[(3, 4, 5), ()], # scalar broadcasting
[(), (1, 2, 3)], # scalar broadcasting
[(4, 3), (4, 1)], # single broadcasting
[(3, 4, 5), (3, 1, 5)] # single broadcasting in the middle
]:
for oneType in types:
if oneType == 'float16':
rtol = 1e-2
atol = 1e-2
else:
rtol=1e-3
atol = 1e-5
test_hypot = TestHypot()
if hybridize:
test_hypot.hybridize()
x1 = rand_ndarray(shape1, dtype=oneType).as_np_ndarray()
x2 = rand_ndarray(shape2, dtype=oneType).as_np_ndarray()
x11 = x1.asnumpy()
x21 = x2.asnumpy()
x1.attach_grad()
x2.attach_grad()
np_out = np.hypot(x1.asnumpy(), x2.asnumpy())
with mx.autograd.record():
mx_out = test_hypot(x1, x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)
mx_out.backward()
np_backward_1 = x11 / np_out
np_backward_2 = x21 / np_out
np_backward_1 = dimReduce(np_backward_1, x11)
np_backward_2 = dimReduce(np_backward_2, x21)
assert_almost_equal(x1.grad.asnumpy(), np_backward_1, rtol=rtol, atol=atol)
assert_almost_equal(x2.grad.asnumpy(), np_backward_2, rtol=rtol, atol=atol)

mx_out = mx.np.hypot(x1, x2)
np_out = np.hypot(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()
Loading

0 comments on commit 1783ed1

Please sign in to comment.