Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Numpy] Numpy copysign #15851

Merged
merged 3 commits into from
Sep 15, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 52 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2432,3 +2432,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
else:
raise ValueError("The dimensions must be sequence of ints")
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)

Change the sign of x1 to that of x2, element-wise.

If `x2` is a scalar, its sign will be copied to all elements of `x1`.

Parameters
----------
x1 : ndarray or scalar
Values to change the sign of.
x2 : ndarray or scalar
The sign of `x2` is copied to `x1`.
out : ndarray or None, optional
A location into which the result is stored. It must be of the
right shape and right type to hold the output. If not provided
or `None`,a freshly-allocated array is returned.

Returns
-------
out : ndarray or scalar
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.

Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:

- ``where`` param is not supported.

Examples
--------
>>> np.copysign(1.3, -1)
-1.3
>>> 1/np.copysign(0, 1)
inf
>>> 1/np.copysign(0, -1)
-inf

>>> a = np.array([-1, 0, 1])
>>> np.copysign(a, -1.1)
array([-1., -0., -1.])
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)
53 changes: 52 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -3935,3 +3935,54 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
"""
return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx)
# pylint: enable=redefined-outer-name


@set_module('mxnet.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)

Change the sign of x1 to that of x2, element-wise.

If `x2` is a scalar, its sign will be copied to all elements of `x1`.

Parameters
----------
x1 : ndarray or scalar
Values to change the sign of.
x2 : ndarray or scalar
The sign of `x2` is copied to `x1`.
out : ndarray or None, optional
A location into which the result is stored. It must be of the
right shape and right type to hold the output. If not provided
or `None`,a freshly-allocated array is returned.

Returns
-------
out : ndarray or scalar
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.

Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:

- ``where`` param is not supported.

Examples
--------
>>> np.copysign(1.3, -1)
-1.3
>>> 1/np.copysign(0, 1)
inf
>>> 1/np.copysign(0, -1)
-inf

>>> a = np.array([-1, 0, 1])
>>> np.copysign(a, -1.1)
array([-1., -0., -1.])
>>> np.copysign(a, np.arange(3)-1)
array([-1., 0., 1.])
"""
return _mx_nd_np.copysign(x1, x2, out=out)
36 changes: 35 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign']


def _num_outputs(sym):
Expand Down Expand Up @@ -2744,4 +2744,38 @@ def indices(dimensions, dtype=_np.int32, ctx=None):
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def copysign(x1, x2, out=None):
r"""copysign(x1, x2, out=None)

Change the sign of x1 to that of x2, element-wise.

If `x2` is a scalar, its sign will be copied to all elements of `x1`.

Parameters
----------
x1 : _Symbol or scalar
Values to change the sign of.
x2 : _Symbol or scalar
The sign of `x2` is copied to `x1`.
out : _Symbol or None
Dummy parameter to keep the consistency with the ndarray counterpart.

Returns
-------
out : _Symbol
The values of `x1` with the sign of `x2`.
This is a scalar if both `x1` and `x2` are scalars.

Notes
-------
This function differs from the original `numpy.copysign
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.copysign.html>`_ in
the following aspects:

- ``where`` param is not supported.
"""
return _ufunc_helper(x1, x2, _npi.copysign, _np.copysign, _npi.copysign_scalar, _npi.rcopysign_scalar, out)


_set_np_symbol_class(_Symbol)
10 changes: 10 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,16 @@ MXNET_BINARY_MATH_OP(rdiv, math::id(b) / math::id(a));

MXNET_BINARY_MATH_OP(rdiv_grad, -math::id(b) / math::sqr(a));

MXNET_BINARY_MATH_OP(copysign, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? a : -a);

MXNET_BINARY_MATH_OP(copysign_grad, (a >= 0 && b >= 0) || (a < 0 && b < 0) ? 1: -1);

MXNET_BINARY_MATH_OP(copysign_rgrad, 0);

MXNET_BINARY_MATH_OP(rcopysign, (b >= 0 && a >= 0) || (b < 0 && a < 0) ? b : -b);

MXNET_BINARY_MATH_OP(rcopysign_grad, 0);

struct mod : public mxnet_op::tunable {
template<typename DType>
MSHADOW_XINLINE static typename enable_if<!is_unsigned<DType>::value, DType>::type
Expand Down
36 changes: 36 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_power)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::power>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_broadcast_power"});

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_copysign)
.describe(R"code()code" ADD_FILELINE)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::copysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign"});

NNVM_REGISTER_OP(_backward_npi_copysign)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_add_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, op::mshadow_op::plus>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_copy"});
Expand Down Expand Up @@ -108,5 +128,21 @@ MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rpower>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_backward_rpower_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::copysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_copysign_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rcopysign>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rcopysign_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::copysign_grad>);

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_backward_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<cpu>",
BinaryScalarOp::Backward<cpu, mshadow_op::rcopysign_grad>);

} // namespace op
} // namespace mxnet
21 changes: 21 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ NNVM_REGISTER_OP(_npi_mod)
NNVM_REGISTER_OP(_npi_power)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::power>);

NNVM_REGISTER_OP(_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::copysign>);

NNVM_REGISTER_OP(_backward_npi_copysign)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::copysign_grad,
mshadow_op::copysign_rgrad>);

NNVM_REGISTER_OP(_npi_add_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, op::mshadow_op::plus>);

Expand All @@ -66,5 +73,19 @@ NNVM_REGISTER_OP(_npi_power_scalar)
NNVM_REGISTER_OP(_npi_rpower_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rpower>);

NNVM_REGISTER_OP(_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::copysign>);

NNVM_REGISTER_OP(_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rcopysign>);

NNVM_REGISTER_OP(_backward_npi_copysign_scalar)
.set_attr<FCompute>("FCompute<gpu>",
BinaryScalarOp::Backward<gpu, mshadow_op::copysign_grad>);

NNVM_REGISTER_OP(_backward_npi_rcopysign_scalar)
.set_attr<FCompute>("FCompute<gpu>",
BinaryScalarOp::Backward<gpu, mshadow_op::rcopysign_grad>);

} // namespace op
} // namespace mxnet
5 changes: 5 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::elu); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rpower_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::power_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::copysign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rcopysign); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::copysign_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rcopysign_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::xelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::gelu_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::prelu_grad); // NOLINT()
Expand Down
105 changes: 105 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,111 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_ret.asnumpy(), np_ret, atol=1e-5, rtol=1e-4)


@with_seed()
@use_np
def test_np_copysign():
class TestCopysign(HybridBlock):
def __init__(self):
super(TestCopysign, self).__init__()

def hybrid_forward(self, F, a1, a2):
return F.np.copysign(a1, a2)

def get_grad(a1, a2):
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
_np.logical_and(a1 >= 0, a2 >= 0))
sign = 2 * sign.astype(int) - 1
sign = sign.reshape(-1, *a1.shape)
sign = _np.sum(sign, axis=0)
return sign, _np.zeros_like(a2)

def get_grad_left(a1, a2):
sign = _np.logical_or(_np.logical_and(a1 < 0, a2 < 0),
_np.logical_and(a1 >= 0, a2 >= 0))
sign = 2 * sign.astype(int) - 1
sign = sign.reshape(a1.shape)
return sign

def get_grad_right(a1, a2):
return _np.zeros_like(a2)

shapes = [
(),
(1),
(2, 1),
(3, 2, 1),
(4, 3, 2, 1),
(2, 4, 3, 2, 1)
]
types = ['float16', 'float32', 'float64', 'int8', 'int32', 'int64']
for a1shape in shapes:
for a2shape in shapes:
for hybridize in [True, False]:
for dtype in types:
test_copysign = TestCopysign()
if hybridize:
test_copysign.hybridize()
rtol = 1e-3
atol = 1e-5
a1_np = _np.array(_np.random.uniform(-1.0, 1.0, a1shape), dtype=dtype)
a2_np = _np.array(_np.random.uniform(-1.0, 1.0, a2shape), dtype=dtype)
a1 = np.array(a1_np, dtype=dtype)
a2 = np.array(a2_np, dtype=dtype)
a1.attach_grad()
a2.attach_grad()
expected_np = _np.copysign(a1_np, a2_np)
with mx.autograd.record():
mx_out = test_copysign(a1, a2)
assert mx_out.shape == expected_np.shape
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test if mx_out.dtype matches expected_np.dtype. I think we can later add a utility to match the shape + dtype.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ignore my previous comments. I think we'd better check the return types after we support arbitrary dtype combinations in deepnumpy. So it's okay to not check the return types now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree that in the future we should do it, after more dtypes and casting is supported.

assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
a1_grad, a2_grad = get_grad(a1_np, a2_np)
assert_almost_equal(a1.grad.asnumpy(), a1_grad, rtol=rtol, atol=atol)
assert_almost_equal(a2.grad.asnumpy(), a2_grad, rtol=rtol, atol=atol)

# Test imperative once again
mx_out = np.copysign(a1, a2)
expected_np = _np.copysign(a1_np, a2_np)
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

types = ['float16', 'float32', 'float64']
for x_shape in shapes:
for dtype in types:
# Test left
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
scalar = _np.random.uniform(-2.0, 2.0)
x = np.array(x_np, dtype=dtype)
x.attach_grad()
expected_np = _np.copysign(x_np, scalar)
with mx.autograd.record():
mx_out = np.copysign(x, scalar)
assert mx_out.shape == expected_np.shape
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
x_grad = get_grad_left(x_np, scalar)
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)

# Test right
x_np = _np.array(_np.random.uniform(-2.0, 2.0, x_shape), dtype=dtype)
scalar = _np.random.uniform(-2.0, 2.0)
x = np.array(x_np, dtype=dtype)
x.attach_grad()
expected_np = _np.copysign(scalar, x_np)
with mx.autograd.record():
mx_out = np.copysign(scalar, x)
assert mx_out.shape == expected_np.shape
assert_almost_equal(mx_out.asnumpy(), expected_np, rtol=rtol, atol=atol)

# Test gradient
mx_out.backward()
x_grad = get_grad_right(scalar, x_np)
assert_almost_equal(x.grad.asnumpy(), x_grad, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()