diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 02e42145fb18..e380b4937168 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -37,10 +37,10 @@ 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', - 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] + 'blackman', 'flip', 'around', 'round', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', + 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', + 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] @set_module('mxnet.ndarray.numpy') @@ -4737,6 +4737,25 @@ def around(x, decimals=0, out=None, **kwargs): raise TypeError('type {} not supported'.format(str(type(x)))) +@set_module('mxnet.ndarray.numpy') +def round(x, decimals=0, out=None, **kwargs): + r""" + round_(a, decimals=0, out=None) + Round an array to the given number of decimals. + + See Also + -------- + around : equivalent function; see for details. + """ + from ...numpy import ndarray + if isinstance(x, numeric_types): + return _np.around(x, decimals, **kwargs) + elif isinstance(x, ndarray): + return _npi.around(x, decimals, out=out, **kwargs) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def arctan2(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a1b0e016445a..22094a1621d2 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -55,9 +55,9 @@ 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', - 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', - 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', + 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'round', 'arctan2', + 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', + 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] @@ -1558,13 +1558,13 @@ def norm(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute norm') - def round(self, *args, **kwargs): + def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`round`. The arguments are the same as for :py:func:`round`, with this array as data. """ - raise NotImplementedError + return round(self, decimals=decimals, out=out, **kwargs) def rint(self, *args, **kwargs): """Convenience fluent method for :py:func:`rint`. @@ -6456,6 +6456,19 @@ def around(x, decimals=0, out=None, **kwargs): return _mx_nd_np.around(x, decimals, out=out, **kwargs) +@set_module('mxnet.numpy') +def round(x, decimals=0, out=None, **kwargs): + r""" + round_(a, decimals=0, out=None) + Round an array to the given number of decimals. + + See Also + -------- + around : equivalent function; see for details. + """ + return _mx_nd_np.around(x, decimals, out=out, **kwargs) + + @set_module('mxnet.numpy') @wrap_np_binary_func def arctan2(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index c7e9dd1398eb..9aa755fb436e 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -86,6 +86,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'argmin', 'argmax', 'around', + 'round', 'argsort', 'append', 'broadcast_arrays', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 3ee385660715..0b341b804758 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -45,10 +45,10 @@ 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'column_stack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', - 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', - 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', - 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] + 'blackman', 'flip', 'around', 'round', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', + 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', + 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', + 'shares_memory', 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] @set_module('mxnet.symbol.numpy') @@ -665,13 +665,13 @@ def norm(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute norm') - def round(self, *args, **kwargs): + def round(self, decimals=0, out=None, **kwargs): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`round`. The arguments are the same as for :py:func:`round`, with this array as data. """ - raise NotImplementedError + return round(self, decimals=decimals, out=out, **kwargs) def rint(self, *args, **kwargs): """Convenience fluent method for :py:func:`rint`. @@ -4524,6 +4524,24 @@ def around(x, decimals=0, out=None, **kwargs): raise TypeError('type {} not supported'.format(str(type(x)))) +@set_module('mxnet.symbol.numpy') +def round(x, decimals=0, out=None, **kwargs): + r""" + round_(a, decimals=0, out=None) + Round an array to the given number of decimals. + + See Also + -------- + around : equivalent function; see for details. + """ + if isinstance(x, numeric_types): + return _np.around(x, decimals, **kwargs) + elif isinstance(x, _Symbol): + return _npi.around(x, decimals, out=out, **kwargs) + else: + raise TypeError('type {} not supported'.format(str(type(x)))) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def arctan2(x1, x2, out=None, **kwargs): diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 9b445044a3c1..3d26ee28b22e 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -594,6 +594,10 @@ def _add_workload_around(): OpArgMngr.add_workload('around', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1) +def _add_workload_round(): + OpArgMngr.add_workload('round', np.array([1.56, 72.54, 6.35, 3.25]), decimals=1) + + def _add_workload_argsort(): for dtype in [np.int32, np.float32]: a = np.arange(101, dtype=dtype) @@ -1442,6 +1446,7 @@ def _prepare_workloads(): _add_workload_argmin() _add_workload_argmax() _add_workload_around() + _add_workload_round() _add_workload_argsort() _add_workload_append() _add_workload_bincount() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 4bbf9b8040e2..3f9f1d6677cc 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -4435,6 +4435,38 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_round(): + class TestRound(HybridBlock): + def __init__(self, decimals): + super(TestRound, self).__init__() + self.decimals = decimals + + def hybrid_forward(self, F, x): + return F.np.round(x, self.decimals) + + shapes = [(), (1, 2, 3), (1, 0)] + types = ['int32', 'int64', 'float32', 'float64'] + for hybridize in [True, False]: + for oneType in types: + rtol, atol = 1e-3, 1e-5 + for shape in shapes: + for d in range(-5, 6): + test_round = TestRound(d) + if hybridize: + test_round.hybridize() + x = rand_ndarray(shape, dtype=oneType).as_np_ndarray() + np_out = _np.round(x.asnumpy(), d) + mx_out = test_round(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + mx_out = np.round(x, d) + np_out = _np.round(x.asnumpy(), d) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + @with_seed() @use_np def test_np_nonzero():