Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[numpy] add op round (#17175)
Browse files Browse the repository at this point in the history
* add round

* sanity

* space
  • Loading branch information
Yiyan66 authored and haojin2 committed Dec 26, 2019
1 parent 07913f9 commit d26dd15
Show file tree
Hide file tree
Showing 6 changed files with 103 additions and 15 deletions.
27 changes: 23 additions & 4 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,10 @@
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']
'blackman', 'flip', 'around', 'round', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero',
'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4737,6 +4737,25 @@ def around(x, decimals=0, out=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
from ...numpy import ndarray
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, ndarray):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.ndarray.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
Expand Down
23 changes: 18 additions & 5 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign',
'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'round', 'arctan2',
'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity',
'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']

Expand Down Expand Up @@ -1558,13 +1558,13 @@ def norm(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute norm')

def round(self, *args, **kwargs):
def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`round`.
The arguments are the same as for :py:func:`round`, with
this array as data.
"""
raise NotImplementedError
return round(self, decimals=decimals, out=out, **kwargs)

def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
Expand Down Expand Up @@ -6456,6 +6456,19 @@ def around(x, decimals=0, out=None, **kwargs):
return _mx_nd_np.around(x, decimals, out=out, **kwargs)


@set_module('mxnet.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
return _mx_nd_np.around(x, decimals, out=out, **kwargs)


@set_module('mxnet.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'argmin',
'argmax',
'around',
'round',
'argsort',
'append',
'broadcast_arrays',
Expand Down
30 changes: 24 additions & 6 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']
'blackman', 'flip', 'around', 'round', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal',
'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide',
'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -665,13 +665,13 @@ def norm(self, *args, **kwargs):
"""
raise AttributeError('_Symbol object has no attribute norm')

def round(self, *args, **kwargs):
def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`round`.
The arguments are the same as for :py:func:`round`, with
this array as data.
"""
raise NotImplementedError
return round(self, decimals=decimals, out=out, **kwargs)

def rint(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`rint`.
Expand Down Expand Up @@ -4524,6 +4524,24 @@ def around(x, decimals=0, out=None, **kwargs):
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.symbol.numpy')
def round(x, decimals=0, out=None, **kwargs):
r"""
round_(a, decimals=0, out=None)
Round an array to the given number of decimals.
See Also
--------
around : equivalent function; see for details.
"""
if isinstance(x, numeric_types):
return _np.around(x, decimals, **kwargs)
elif isinstance(x, _Symbol):
return _npi.around(x, decimals, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))


@set_module('mxnet.symbol.numpy')
@wrap_np_binary_func
def arctan2(x1, x2, out=None, **kwargs):
Expand Down
5 changes: 5 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,10 @@ def _add_workload_around():
OpArgMngr.add_workload('around', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1)


def _add_workload_round():
OpArgMngr.add_workload('round', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1)


def _add_workload_argsort():
for dtype in [np.int32, np.float32]:
a = np.arange(101, dtype=dtype)
Expand Down Expand Up @@ -1442,6 +1446,7 @@ def _prepare_workloads():
_add_workload_argmin()
_add_workload_argmax()
_add_workload_around()
_add_workload_round()
_add_workload_argsort()
_add_workload_append()
_add_workload_bincount()
Expand Down
32 changes: 32 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4435,6 +4435,38 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_round():
class TestRound(HybridBlock):
def __init__(self, decimals):
super(TestRound, self).__init__()
self.decimals = decimals

def hybrid_forward(self, F, x):
return F.np.round(x, self.decimals)

shapes = [(), (1, 2, 3), (1, 0)]
types = ['int32', 'int64', 'float32', 'float64']
for hybridize in [True, False]:
for oneType in types:
rtol, atol = 1e-3, 1e-5
for shape in shapes:
for d in range(-5, 6):
test_round = TestRound(d)
if hybridize:
test_round.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
np_out = _np.round(x.asnumpy(), d)
mx_out = test_round(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)

mx_out = np.round(x, d)
np_out = _np.round(x.asnumpy(), d)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_nonzero():
Expand Down

0 comments on commit d26dd15

Please sign in to comment.