Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] Numpy copysign (#15851)
Browse files Browse the repository at this point in the history
* add numpy compatible copysign

* fix scalar op registration error

* add test
  • Loading branch information
hzfan authored and sxjscience committed Sep 15, 2019
1 parent e9e267e commit 90091b1
Show file tree
Hide file tree
Showing 8 changed files with 316 additions and 3 deletions.
53 changes: 52 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2432,3 +2432,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
else:
raise ValueError("The dimensions must be sequence of ints")
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)
Change the sign of x1 to that of x2, element-wise.
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
Parameters
----------
x1 : ndarray or scalar
Values to change the sign of.
x2 : ndarray or scalar
The sign of `x2` is copied to `x1`.
out : ndarray or None, optional
A location into which the result is stored. It must be of the
right shape and right type to hold the output. If not provided
or `None`,a freshly-allocated array is returned.
Returns
-------
out : ndarray or scalar
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.
Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:
- ``where`` param is not supported.
Examples
--------
>>> np.copysign(1.3, -1)
-1.3
>>> 1/np.copysign(0, 1)
inf
>>> 1/np.copysign(0, -1)
-inf
>>> a = np.array([-1, 0, 1])
>>> np.copysign(a, -1.1)
array([-1., -0., -1.])
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
53 changes: 52 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -3935,3 +3935,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
"""
return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
# pylint: enable=redefined-outer-name


@set_module('mxnet.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)
Change the sign of x1 to that of x2, element-wise.
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
Parameters
----------
x1 : ndarray or scalar
Values to change the sign of.
x2 : ndarray or scalar
The sign of `x2` is copied to `x1`.
out : ndarray or None, optional
A location into which the result is stored. It must be of the
right shape and right type to hold the output. If not provided
or `None`,a freshly-allocated array is returned.
Returns
-------
out : ndarray or scalar
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.
Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:
- ``where`` param is not supported.
Examples
--------
>>> np.copysign(1.3, -1)
-1.3
>>> 1/np.copysign(0, 1)
inf
>>> 1/np.copysign(0, -1)
-inf
>>> a = np.array([-1, 0, 1])
>>> np.copysign(a, -1.1)
array([-1., -0., -1.])
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _mx_nd_np.copysign(x1, x2, out=out)
36 changes: 35 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']


def _num_outputs(sym):
Expand Down Expand Up @@ -2744,4 +2744,38 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)
Change the sign of x1 to that of x2, element-wise.
If `x2` is a scalar, its sign will be copied to all elements of `x1`.
Parameters
----------
x1 : _Symbol or scalar
Values to change the sign of.
x2 : _Symbol or scalar
The sign of `x2` is copied to `x1`.
out : _Symbol or None
Dummy parameter to keep the consistency with the ndarray counterpart.
Returns
-------
out : _Symbol
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.
Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:
- ``where`` param is not supported.
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)


_set_np_symbol_class(_Symbol)
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,16 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));

MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));

MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);

MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);

MXNET_BINARY_MATH_OP(copysign_rgrad, 0);

MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);

MXNET_BINARY_MATH_OP(rcopysign_grad, 0);

struct mod : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
Expand Down
36 changes: 36 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::copysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});

NNVM_REGISTER_OP(_backward_npi_copysign)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
Expand Down Expand Up @@ -108,5 +128,21 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rcopysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);

} // namespace op
} // namespace mxnet
21 changes: 21 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ NNVM_REGISTER_OP(_npi_mod)
NNVM_REGISTER_OP(_npi_power)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);

NNVM_REGISTER_OP(_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::copysign>);

NNVM_REGISTER_OP(_backward_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);

NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);

Expand All @@ -66,5 +73,19 @@ NNVM_REGISTER_OP(_npi_power_scalar)
NNVM_REGISTER_OP(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);

NNVM_REGISTER_OP(_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::copysign>);

NNVM_REGISTER_OP(_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rcopysign>);

NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<gpu>",
BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);

NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<gpu>",
BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);

} // namespace op
} // namespace mxnet
5 changes: 5 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
Expand Down
105 changes: 105 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,111 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)


@with_seed()
@use_np
def test_np_copysign():
class TestCopysign(HybridBlock):
def __init__(self):
super(TestCopysign, self).__init__()

def hybrid_forward(self, F, a1, a2):
return F.np.copysign(a1, a2)

def get_grad(a1, a2):
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
_np.logical_and(a1 >= 0, a2 >= 0))
sign = 2 * sign.astype(int) - 1
sign = sign.reshape(-1, *a1.shape)
sign = _np.sum(sign, axis=0)
return sign, _np.zeros_like(a2)

def get_grad_left(a1, a2):
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
_np.logical_and(a1 >= 0, a2 >= 0))
sign = 2 * sign.astype(int) - 1
sign = sign.reshape(a1.shape)
return sign

def get_grad_right(a1, a2):
return _np.zeros_like(a2)

shapes = [
(),
(1),
(2, 1),
(3, 2, 1),
(4, 3, 2, 1),
(2, 4, 3, 2, 1)
]
types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
for a1shape in shapes:
for a2shape in shapes:
for hybridize in [True, False]:
for dtype in types:
test_copysign = TestCopysign()
if hybridize:
test_copysign.hybridize()
rtol = 1e-3
atol = 1e-5
a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype)
a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype)
a1 = np.array(a1_np, dtype=dtype)
a2 = np.array(a2_np, dtype=dtype)
a1.attach_grad()
a2.attach_grad()
expected_np = _np.copysign(a1_np, a2_np)
with mx.autograd.record():
mx_out = test_copysign(a1, a2)
assert mx_out.shape == expected_np.shape
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
a1_grad, a2_grad = get_grad(a1_np, a2_np)
assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)

# Test imperative once again
mx_out = np.copysign(a1, a2)
expected_np = _np.copysign(a1_np, a2_np)
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

types = ['float16', 'float32', 'float64']
for x_shape in shapes:
for dtype in types:
# Test left
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
scalar = _np.random.uniform(-2.0, 2.0)
x = np.array(x_np, dtype=dtype)
x.attach_grad()
expected_np = _np.copysign(x_np, scalar)
with mx.autograd.record():
mx_out = np.copysign(x, scalar)
assert mx_out.shape == expected_np.shape
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
x_grad = get_grad_left(x_np, scalar)
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)

# Test right
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
scalar = _np.random.uniform(-2.0, 2.0)
x = np.array(x_np, dtype=dtype)
x.attach_grad()
expected_np = _np.copysign(scalar, x_np)
with mx.autograd.record():
mx_out = np.copysign(scalar, x)
assert mx_out.shape == expected_np.shape
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
x_grad = get_grad_right(scalar, x_np)
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 90091b1

Please sign in to comment.