Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] operator around #16126

Merged
merged 1 commit into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip']
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2896,3 +2896,60 @@ def flip(m, axis=None, out=None):
return _npi.flip(m, axis, out=out)
else:
raise TypeError('type {} not supported'.format(str(type(m))))


@set_module('mxnet.ndarray.numpy')
def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)

Evenly round to the given number of decimals.
Parameters
----------
x : ndarray or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.

Returns
-------
rounded_array : ndarray or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.

Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.

This function differs from the original numpy.prod in the following aspects:

- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.

Examples
--------
>>> np.around([0.37, 1.64])
array([ 0., 2.])
>>> np.around([0.37, 1.64], decimals=1)
array([ 0.4, 1.6])
>>> np.around([.5, 1.5, 2.5, 3.5, 4.5]) # rounds to nearest even value
array([ 0., 2., 2., 4., 4.])
>>> np.around([1, 2, 3, 11], decimals=1) # ndarray of ints is returned
array([ 1, 2, 3, 11])
>>> np.around([1, 2, 3, 11], decimals=-1)
array([ 0, 0, 0, 10])
"""
from ...numpy import ndarray
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, ndarray):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))
54 changes: 53 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip']
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -4429,3 +4429,55 @@ def flip(m, axis=None, out=None):
[3, 2]]])
"""
return _mx_nd_np.flip(m, axis, out=out)


@set_module('mxnet.numpy')
def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)

Evenly round to the given number of decimals.

Parameters
----------
x : ndarray or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.

Returns
-------
rounded_array : ndarray or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.

Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.

This function differs from the original numpy.prod in the following aspects:

- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.

Examples
--------
>>> np.around([0.37, 1.64])
array([ 0., 2.])
>>> np.around([0.37, 1.64], decimals=1)
array([ 0.4, 1.6])
>>> np.around([.5, 1.5, 2.5, 3.5, 4.5]) # rounds to nearest even value
array([ 0., 2., 2., 4., 4.])
>>> np.around([1, 2, 3, 11], decimals=1) # ndarray of ints is returned
array([ 1, 2, 3, 11])
>>> np.around([1, 2, 3, 11], decimals=-1)
array([ 0, 0, 0, 10])
"""
return _mx_nd_np.around(x, decimals, out=out, **kwargs)
45 changes: 44 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip']
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around']


def _num_outputs(sym):
Expand Down Expand Up @@ -3129,4 +3129,47 @@ def flip(m, axis=None, out=None):
raise TypeError('type {} not supported'.format(str(type(m))))


@set_module('mxnet.symbol.numpy')
def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)

Evenly round to the given number of decimals.
Parameters
----------
x : _Symbol or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : _Symbol, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.

Returns
-------
rounded_array : _Symbol or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.

Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.

This function differs from the original numpy.prod in the following aspects:

- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.
"""
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, _Symbol):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


_set_np_symbol_class(_Symbol)
34 changes: 34 additions & 0 deletions src/operator/numpy/np_elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,5 +362,39 @@ computed element-wise.
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arctanh" });

inline bool AroundOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

CHECK(in_attrs->at(0) != mshadow::kFloat16) << "Do not support `float16` as input.\n";
return out_attrs->at(0) != -1;
}

DMLC_REGISTER_PARAMETER(AroundParam);

NNVM_REGISTER_OP(_npi_around)
.set_attr_parser(ParamParser<AroundParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", AroundOpType)
.set_attr<FCompute>("FCompute<cpu>", AroundOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("x", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(AroundParam::__FIELDS__())
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/np_elemwise_unary_op_basic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,8 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arccosh, mshadow_op::arccosh);

MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctanh, mshadow_op::arctanh);

NNVM_REGISTER_OP(_npi_around)
.set_attr<FCompute>("FCompute<gpu>", AroundOpForward<gpu>);

} // namespace op
} // namespace mxnet
83 changes: 83 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,89 @@ struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
}
};

struct AroundParam : public dmlc::Parameter<AroundParam> {
int decimals;
DMLC_DECLARE_PARAMETER(AroundParam) {
DMLC_DECLARE_FIELD(decimals)
.set_default(0)
.describe("Number of decimal places to round to.");
}
};

template<int req>
struct around_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
const int decimals) {
int d = 0;
DType temp = in_data[i];
DType roundtemp;
while (d != decimals) {
if (decimals > 0) {
d++;
temp *= 10;
} else {
d--;
temp /= 10;
}
}
roundtemp = (DType)round(static_cast<double>(temp));
// If temp is x.5 and roundtemp is odd number, decrease or increase roundtemp by 1.
// For example, in numpy, around(0.5) should be 0 but in c, round(0.5) is 1.
if (roundtemp - temp == 0.5 && (static_cast<int>(roundtemp)) % 2 != 0) {
roundtemp -= 1;
} else if (temp - roundtemp == 0.5 && (static_cast<int>(roundtemp)) % 2 != 0) {
roundtemp += 1;
}
while (d != 0) {
if (roundtemp == 0) {
break;
}
if (decimals > 0) {
d--;
roundtemp /= 10;
} else {
d++;
roundtemp *= 10;
}
}
KERNEL_ASSIGN(out_data[i], req, roundtemp);
}
};

template<typename xpu>
void AroundOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const AroundParam& param = nnvm::get<AroundParam>(attrs.parsed);
using namespace mxnet_op;
// if the type is uint8, int8, int32 or int64 and decimals is greater than 0
// we simply return the number back.
if (in_data.type_flag_ >= mshadow::kUint8 && in_data.type_flag_ <= mshadow::kInt64 \
&& param.decimals > 0) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
Kernel<mshadow_op::identity_with_cast, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>());
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<around_forward<req_type>, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
param.decimals);
});
});
}
}

/*! \brief Unary compute */
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \
NNVM_REGISTER_OP(__name$) \
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2346,7 +2346,7 @@ def test_arange_like_dtype():
out = mod.forward(is_train=False)
for v in out:
assert v.dtype == t


if __name__ == '__main__':
import nose
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,6 +2348,38 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_around():
class TestAround(HybridBlock):
def __init__(self, decimals):
super(TestAround, self).__init__()
self.decimals = decimals

def hybrid_forward(self, F, x):
return F.np.around(x, self.decimals)

shapes = [(), (1, 2, 3), (1, 0)]
types = ['int32', 'int64', 'float32', 'float64']
for hybridize in [True, False]:
for oneType in types:
rtol, atol = 1e-3, 1e-5
for shape in shapes:
for d in range(-5, 6):
test_around = TestAround(d)
if hybridize:
test_around.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
np_out = _np.around(x.asnumpy(), d)
mx_out = test_around(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)

mx_out = np.around(x, d)
np_out = _np.around(x.asnumpy(), d)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()