Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy operator around
Browse files Browse the repository at this point in the history
* change the name of argument

* add doc in three files and fix some bug

* change the data type in .h and add test function

    cancel optimization when abs(temp) < 0.5
    modify test on cpu and add test on gpu
    do not support float16
    edit testcase on gpu and add 'Do not support float16 on doc'

* edit doc: support scalar

* adjust the format

* add license

* fix format error

* delete gpu test

* move around to np_elemwise_unary_op_basic
  • Loading branch information
Ying committed Sep 9, 2019
1 parent 1c67928 commit 7a3e387
Show file tree
Hide file tree
Showing 8 changed files with 325 additions and 4 deletions.
59 changes: 58 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'around']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -2363,3 +2363,60 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.ndarray.numpy')
def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)
Evenly round to the given number of decimals.
Parameters
----------
x : ndarray or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
rounded_array : ndarray or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.
Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.
This function differs from the original numpy.prod in the following aspects:
- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.
Examples
--------
>>> np.around([0.37, 1.64])
array([ 0., 2.])
>>> np.around([0.37, 1.64], decimals=1)
array([ 0.4, 1.6])
>>> np.around([.5, 1.5, 2.5, 3.5, 4.5]) # rounds to nearest even value
array([ 0., 2., 2., 4., 4.])
>>> np.around([1, 2, 3, 11], decimals=1) # ndarray of ints is returned
array([ 1, 2, 3, 11])
>>> np.around([1, 2, 3, 11], decimals=-1)
array([ 0, 0, 0, 10])
"""
from ...numpy import ndarray
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, ndarray):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))
53 changes: 52 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'around']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -3808,3 +3808,54 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)
Evenly round to the given number of decimals.
Parameters
----------
x : ndarray or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
rounded_array : ndarray or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.
Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.
This function differs from the original numpy.prod in the following aspects:
- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.
Examples
--------
>>> np.around([0.37, 1.64])
array([ 0., 2.])
>>> np.around([0.37, 1.64], decimals=1)
array([ 0.4, 1.6])
>>> np.around([.5, 1.5, 2.5, 3.5, 4.5]) # rounds to nearest even value
array([ 0., 2., 2., 4., 4.])
>>> np.around([1, 2, 3, 11], decimals=1) # ndarray of ints is returned
array([ 1, 2, 3, 11])
>>> np.around([1, 2, 3, 11], decimals=-1)
array([ 0, 0, 0, 10])
"""
return _mx_nd_np.around(x, decimals, out=out, **kwargs)
45 changes: 44 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'around']


def _num_outputs(sym):
Expand Down Expand Up @@ -2678,4 +2678,47 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint:
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def around(x, decimals=0, out=None, **kwargs):
r"""
around(x, decimals=0, out=None)
Evenly round to the given number of decimals.
Parameters
----------
x : _Symbol or scalar
Input data.
decimals : int, optional
Number of decimal places to round to (default: 0). If
decimals is negative, it specifies the number of positions to
the left of the decimal point.
out : _Symbol, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
rounded_array : _Symbol or scalar
An array of the same type as `x`, containing the rounded values.
A reference to the result is returned.
Notes
-----
For values exactly halfway between rounded decimal values, NumPy
rounds to the nearest even value. Thus 1.5 and 2.5 round to 2.0,
-0.5 and 0.5 round to 0.0, etc.
This function differs from the original numpy.prod in the following aspects:
- Cannot cast type automatically. Dtype of `out` must be same as the expected one.
- Cannot support complex-valued number.
"""
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, _Symbol):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


_set_np_symbol_class(_Symbol)
38 changes: 38 additions & 0 deletions src/operator/numpy/np_elemwise_unary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -362,5 +362,43 @@ computed element-wise.
)code" ADD_FILELINE)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{ "_backward_arctanh" });

inline bool AroundOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));

if (in_attrs->at(0) == mshadow::kFloat16) {
std::ostringstream os;
os << "Do not support `float16` as input.\n";
throw ::mxnet::op::InferTypeError(os.str(), 0);
}
return out_attrs->at(0) != -1;
}

DMLC_REGISTER_PARAMETER(AroundParam);

NNVM_REGISTER_OP(_npi_around)
.set_attr_parser(ParamParser<AroundParam>)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", AroundOpType)
.set_attr<FCompute>("FCompute<cpu>", AroundOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.add_argument("x", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(AroundParam::__FIELDS__())
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/np_elemwise_unary_op_basic.cu
Original file line number Diff line number Diff line change
Expand Up @@ -106,5 +106,8 @@ MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arccosh, mshadow_op::arccosh);

MXNET_OPERATOR_REGISTER_NUMPY_UNARY_GPU(_npi_arctanh, mshadow_op::arctanh);

NNVM_REGISTER_OP(_npi_around)
.set_attr<FCompute>("FCompute<gpu>", AroundOpForward<gpu>);

} // namespace op
} // namespace mxnet
95 changes: 95 additions & 0 deletions src/operator/tensor/elemwise_unary_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,101 @@ struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
}
};

struct AroundParam : public dmlc::Parameter<AroundParam> {
int decimals;
DMLC_DECLARE_PARAMETER(AroundParam) {
DMLC_DECLARE_FIELD(decimals)
.set_default(0)
.describe("Number of decimal places to round to.");
}
};

template<int req>
struct around_forwardint{
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
const int decimals) {
KERNEL_ASSIGN(out_data[i], req, in_data[i]);
}
};

template<int req>
struct around_forward {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, DType* out_data, const DType* in_data,
const int decimals) {
int d = 0;
DType temp = in_data[i];
DType roundtemp;
while (d != decimals) {
if (decimals > 0) {
d++;
temp *= 10;
} else {
d--;
temp /= 10;
}
}
roundtemp = (DType)round(static_cast<double>(temp));
// If temp is x.5 and roundtemp is odd number, decrease or increase roundtemp by 1.
// For example, in numpy, around(0.5) should be 0 but in c, round(0.5) is 1.
if (roundtemp - temp == 0.5 && (static_cast<int>(roundtemp)) % 2 != 0) {
roundtemp -= 1;
} else if (temp - roundtemp == 0.5 && (static_cast<int>(roundtemp)) % 2 != 0) {
roundtemp += 1;
}
while (d != 0) {
if (roundtemp == 0) {
break;
}
if (decimals > 0) {
d--;
roundtemp /= 10;
} else {
d++;
roundtemp *= 10;
}
}
KERNEL_ASSIGN(out_data[i], req, roundtemp);
}
};

template<typename xpu>
void AroundOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& out_data = outputs[0];
const AroundParam& param = nnvm::get<AroundParam>(attrs.parsed);
using namespace mxnet_op;
// if the type is uint8, int8, int32 or int64 and decimals is greater than 0
// we simply return the number back.
if (in_data.type_flag_ >= mshadow::kUint8 && in_data.type_flag_ <= mshadow::kInt64 \
&& param.decimals > 0) {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<around_forwardint<req_type>, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
param.decimals);
});
});
} else {
MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
Kernel<around_forward<req_type>, xpu>::Launch(
s, out_data.Size(), out_data.dptr<DType>(), in_data.dptr<DType>(),
param.decimals);
});
});
}
}

/*! \brief Unary compute */
#define MXNET_OPERATOR_REGISTER_UNARY(__name$) \
NNVM_REGISTER_OP(__name$) \
Expand Down
2 changes: 1 addition & 1 deletion tests/python/gpu/test_operator_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -2345,7 +2345,7 @@ def test_arange_like_dtype():
out = mod.forward(is_train=False)
for v in out:
assert v.dtype == t


if __name__ == '__main__':
import nose
Expand Down
34 changes: 34 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1647,6 +1647,40 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_around():
class TestAround(HybridBlock):
def __init__(self, decimals):
super(TestAround, self).__init__()
# necessary initializations
self.decimals = decimals

def hybrid_forward(self, F, x):
return F.np.around(x, self.decimals)

shapes = [(), (1,), (1, 1), (1, 2, 3), (1, 0), (3, 0, 2)] # test_shapes, remember to include zero-dim shape and zero-size shapes
types = ['int32', 'int64', 'float32', 'double']
for hybridize in [True, False]:
for oneType in types:
rtol=1e-3
atol=1e-5
for shape in shapes:
for d in range(-10, 11):
test_around = TestAround(d)
if hybridize:
test_around.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
np_out = _np.around(x.asnumpy(), d)
mx_out = test_around(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)

mx_out = np.around(x, d)
np_out = _np.around(x.asnumpy(), d)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 7a3e387

Please sign in to comment.