Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Implements ldexp. (#15845)
Browse files Browse the repository at this point in the history
Remove spaces.

Change tests.

Reorganize files.

Change styles.

Add spaces
  • Loading branch information
ckt624 authored and reminisce committed Oct 5, 2019
1 parent 916fbf2 commit adf288c
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 4 deletions.
42 changes: 40 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take']
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -3447,7 +3447,7 @@ def hypot(x1, x2, out=None):
Notes
-----
This function differs from the original numpy.arange in the following aspects:
- Only support float16, float32 and float64.
- Only support float16, float32 and float64.
Examples
--------
Expand All @@ -3464,3 +3464,41 @@ def hypot(x1, x2, out=None):
[ 5., 5., 5.]])
"""
return _ufunc_helper(x1, x2, _npi.hypot, _np.hypot, _npi.hypot_scalar, None, out)


@set_module('mxnet.ndarray.numpy')
def ldexp(x1, x2, out=None):
"""
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.
Parameters
----------
x1 : ndarray or scalar
Array of multipliers.
x2 : ndarray or scalar, int
Array of twos exponents.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The result of ``x1 * 2**x2``.
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
Complex dtypes are not supported, they will raise a TypeError.
Different from numpy, we allow x2 to be float besides int.
`ldexp` is useful as the inverse of `frexp`, if used by itself it is
more clear to simply use the expression ``x1 * 2**x2``.
Examples
--------
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)
40 changes: 39 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take']
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -4980,3 +4980,41 @@ def hypot(x1, x2, out=None):
[ 5., 5., 5.]])
"""
return _mx_nd_np.hypot(x1, x2, out=out)


@set_module('mxnet.numpy')
def ldexp(x1, x2, out=None):
"""
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.
Parameters
----------
x1 : ndarray or scalar
Array of multipliers.
x2 : ndarray or scalar, int
Array of twos exponents.
out : ndarray, optional
A location into which the result is stored. If provided, it must have
a shape that the inputs broadcast to. If not, a freshly-allocated array is returned.
Returns
-------
y : ndarray or scalar
The result of ``x1 * 2**x2``.
This is a scalar if both `x1` and `x2` are scalars.
Notes
-----
Complex dtypes are not supported, they will raise a TypeError.
Different from numpy, we allow x2 to be float besides int.
`ldexp` is useful as the inverse of `frexp`, if used by itself it is
more clear to simply use the expression ``x1 * 2**x2``.
Examples
--------
>>> np.ldexp(5, np.arange(4))
array([ 5., 10., 20., 40.])
"""
return _mx_nd_np.ldexp(x1, x2, out)
31 changes: 30 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take']
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp']


def _num_outputs(sym):
Expand Down Expand Up @@ -3552,4 +3552,33 @@ def unique(ar, return_index=False, return_inverse=False, return_counts=False, ax
return _npi.unique(ar, return_index, return_inverse, return_counts, axis)


@set_module('mxnet.symbol.numpy')
def ldexp(x1, x2, out=None):
"""
ldexp(x1, x2, out=None)
Returns x1 * 2**x2, element-wise.
The mantissas `x1` and twos exponents `x2` are used to construct
floating point numbers ``x1 * 2**x2``.
Parameters
----------
x1 : _Symbol
Array of multipliers.
x2 : _Symbol
Array of twos exponents.
out : _Symbol or None
Dummy parameter to keep the consistency with the ndarray counterpart.
Returns
-------
y : _Symbol
The result of ``x1 * 2**x2``.
Notes
-----
Complex dtypes are not supported, they will raise a TypeError.
Different from numpy, we allow x2 to be float besides int.
`ldexp` is useful as the inverse of `frexp`, if used by itself it is
more clear to simply use the expression ``x1 * 2**x2``.
"""
return _ufunc_helper(x1, x2, _npi.ldexp, _np.ldexp, _npi.ldexp_scalar, _npi.rldexp_scalar, out)


_set_np_symbol_class(_Symbol)
11 changes: 11 additions & 0 deletions src/operator/mshadow_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -368,6 +368,17 @@ MXNET_UNARY_MATH_OP(reciprocal_cube_root, 1.0f / math::cbrt(a));

MXNET_UNARY_MATH_OP(reciprocal_cube_root_grad, -1.0f / (3.0f * math::cbrt(a) * math::id(a)));

/*! \brief used for generate element of ldexp */
MXNET_BINARY_MATH_OP(ldexp, math::id(a) * math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_grad, math::pow(2.0f, b));

MXNET_BINARY_MATH_OP(ldexp_rgrad, math::id(a) * math::pow(2.0f, b) * math::log(2.0f));

MXNET_BINARY_MATH_OP(rldexp, math::id(b) * math::pow(2.0f, a)); // swap a and b if a is scalar.

MXNET_BINARY_MATH_OP(rldexp_grad, math::id(b) * math::pow(2.0f, a) * math::log(2.0f));

/*! \brief used for generate element of round */
MXNET_SIMPLE_UNARY_MATH_OP(round);

Expand Down
37 changes: 37 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,42 @@ NNVM_REGISTER_OP(_npi_lcm_scalar)
.add_argument("scalar", "int", "scalar input")
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);

MXNET_OPERATOR_REGISTER_BINARY_BROADCAST(_npi_ldexp)
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastCompute<cpu, mshadow_op::ldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::ldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_ldexp_scalar"});

MXNET_OPERATOR_REGISTER_NP_BINARY_SCALAR(_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::rldexp>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_npi_rldexp_scalar"});

NNVM_REGISTER_OP(_backward_npi_ldexp)
.set_num_inputs(3)
.set_num_outputs(2)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 1}};
})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::ldexp_grad,
mshadow_op::ldexp_rgrad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_ldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::ldexp_grad>);

MXNET_OPERATOR_REGISTER_BINARY(_backward_npi_rldexp_scalar)
.add_argument("scalar", "float", "scalar value")
.set_attr_parser([](NodeAttrs *attrs) { attrs->parsed = std::stod(attrs->dict["scalar"]); })
.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Backward<cpu, mshadow_op::rldexp_grad>);

} // namespace op
} // namespace mxnet
19 changes: 19 additions & 0 deletions src/operator/numpy/np_elemwise_broadcast_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -118,5 +118,24 @@ NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
NNVM_REGISTER_OP(_npi_lcm_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::lcm>);

NNVM_REGISTER_OP(_npi_ldexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastCompute<gpu, mshadow_op::ldexp>);

NNVM_REGISTER_OP(_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::ldexp>);

NNVM_REGISTER_OP(_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Compute<gpu, mshadow_op::rldexp>);

NNVM_REGISTER_OP(_backward_npi_ldexp)
.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastBackwardUseIn<gpu, mshadow_op::ldexp_grad,
mshadow_op::ldexp_rgrad>);

NNVM_REGISTER_OP(_backward_npi_ldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::ldexp_grad>);

NNVM_REGISTER_OP(_backward_npi_rldexp_scalar)
.set_attr<FCompute>("FCompute<gpu>", BinaryScalarOp::Backward<gpu, mshadow_op::rldexp_grad>);

} // namespace op
} // namespace mxnet
5 changes: 5 additions & 0 deletions src/operator/operator_tune.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,11 @@ IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::lcm); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT()
IMPLEMENT_BLANK_WORKLOAD_FWD(mxnet::op::PopulateFullIdxRspKernel); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::ldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::rldexp); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_grad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::ldexp_rgrad); // NOLINT()
IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::rldexp_grad); // NOLINT()
/*!
* \brief Tuner objects, *not* automatically generated
*/
Expand Down
61 changes: 61 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,67 @@ def test_np_dot():
assert False


@with_seed()
@use_np
def test_np_ldexp():
class TestLdexp(HybridBlock):
def __init__(self):
super(TestLdexp, self).__init__()

def hybrid_forward(self, F, x1, x2):
return F.np.ldexp(x1, x2)

def _np_ldexp(x1, x2):
return x1 * _np.power(2.0, x2)

def dldx(x1, x2):
grad_a = _np.power(2.0, x2)
grad_b = _np_ldexp(x1, x2) * _np.log(2.0)
if len(x1) == 1:
grad_a = _np.sum(grad_a)
if len(x2) == 1:
grad_b = _np.sum(grad_b)
return [grad_a, grad_b]

shapes = [
((3, 1), (3, 1)),
((3, 1, 2), (3, 1, 2)),
((1, ),(1, )),
((1, ), (2, )),
((3, ), (1, )),
((3, 0), (3, 0)), # zero-size shape
((0, 1), (0, 1)), # zero-size shape
((2, 0, 2), (2, 0, 2)), # zero-size shape
]

for hybridize in [True, False]:
for shape1, shape2 in shapes:
for dtype in [_np.float16, _np.float32, _np.float64]:
test_ldexp = TestLdexp()
if hybridize:
test_ldexp.hybridize()
x1 = rand_ndarray(shape=shape1, dtype=dtype).as_np_ndarray()
x1.attach_grad()
x2 = rand_ndarray(shape=shape2, dtype=dtype).as_np_ndarray()
x2.attach_grad()

np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
with mx.autograd.record():
mx_out = test_ldexp(x1, x2)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)

mx_out.backward()
np_backward = dldx(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(x1.grad.asnumpy(), np_backward[0], atol=1e-1, rtol=1e-1)
assert_almost_equal(x2.grad.asnumpy(), np_backward[1], atol=1e-1, rtol=1e-1)

# Test imperative once again
mx_out = np.ldexp(x1, x2)
np_out = _np_ldexp(x1.asnumpy(), x2.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1)


@with_seed()
@use_np
def test_np_sum():
Expand Down

0 comments on commit adf288c

Please sign in to comment.