From c28d8e4a4000e40ac7a9feb74e0a58ae8b895bec Mon Sep 17 00:00:00 2001 From: hanke580 <38852697+hanke580@users.noreply.github.com> Date: Mon, 10 Feb 2020 15:43:05 +0800 Subject: [PATCH] [Numpy] Add sort op (#17393) * [Numpy] Add sort op * Fix sanity * * Fix style * * Add restriction --- python/mxnet/ndarray/numpy/_op.py | 47 +++++++++++++++- python/mxnet/numpy/multiarray.py | 48 +++++++++++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 39 ++++++++++++- src/operator/tensor/ordering_op.cc | 1 + .../unittest/test_numpy_interoperability.py | 12 ++++ tests/python/unittest/test_numpy_op.py | 56 +++++++++++++++++++ 7 files changed, 197 insertions(+), 7 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index d6e76fe77b49..69baa06b96de 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -33,7 +33,8 @@ 'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', + 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'minimum', @@ -1224,6 +1225,50 @@ def argsort(a, axis=-1, kind=None, order=None): return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64') +@set_module('mxnet.ndarray.numpy') +def sort(a, axis=-1, kind=None, order=None): + """ + Return a sorted copy of an array. + + Parameters + ---------- + a : ndarray + Array to be sorted. + axis : int or None, optional + Axis along which to sort. The default is -1 (the last axis). If None, + the flattened array is used. + kind : string, optional + This argument can take any string, but it does not have any effect on the + final result. + order : str or list of str, optional + Not supported yet, will raise NotImplementedError if not None. + + Returns + ------- + sorted_array : ndarray + Array of the same type and shape as `a`. + + Notes + ----- + This operator does not support different sorting algorithms. + + Examples + -------- + >>> a = np.array([[1,4],[3,1]]) + >>> np.sort(a) # sort along the last axis + array([[1, 4], + [1, 3]]) + >>> np.sort(a, axis=None) # sort the flattened array + array([1, 1, 3, 4]) + >>> np.sort(a, axis=0) # sort along the first axis + array([[1, 1], + [3, 4]]) + """ + if order is not None: + raise NotImplementedError("order not supported here") + return _npi.sort(data=a, axis=axis, is_ascend=True) + + @set_module('mxnet.ndarray.numpy') def tensordot(a, b, axes=2): r""" diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 967127434b77..7002c7bc8401 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -56,7 +56,7 @@ 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort', - 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', + 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', @@ -1531,13 +1531,13 @@ def pick(self, *args, **kwargs): """ raise AttributeError('mxnet.numpy.ndarray object has no attribute pick') - def sort(self, *args, **kwargs): + def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sort`. The arguments are the same as for :py:func:`sort`, with this array as data. """ - raise NotImplementedError + raise sort(self, axis=axis, kind=kind, order=order) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -4644,6 +4644,48 @@ def argsort(a, axis=-1, kind=None, order=None): return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order) +@set_module('mxnet.numpy') +def sort(a, axis=-1, kind=None, order=None): + """ + Return a sorted copy of an array. + + Parameters + ---------- + a : ndarray + Array to be sorted. + axis : int or None, optional + Axis along which to sort. The default is -1 (the last axis). If None, + the flattened array is used. + kind : string, optional + This argument can take any string, but it does not have any effect on the + final result. + order : str or list of str, optional + Not supported yet, will raise NotImplementedError if not None. + + Returns + ------- + sorted_array : ndarray + Array of the same type and shape as `a`. + + Notes + ----- + This operator does not support different sorting algorithms. + + Examples + -------- + >>> a = np.array([[1,4],[3,1]]) + >>> np.sort(a) # sort along the last axis + array([[1, 4], + [1, 3]]) + >>> np.sort(a, axis=None) # sort the flattened array + array([1, 1, 3, 4]) + >>> np.sort(a, axis=0) # sort along the first axis + array([[1, 1], + [3, 4]]) + """ + return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order) + + @set_module('mxnet.numpy') def tensordot(a, b, axes=2): r""" diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 3abf77f7c634..d3fa0d67c08e 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -90,6 +90,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'around', 'round', 'argsort', + 'sort', 'append', 'broadcast_arrays', 'broadcast_to', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 86dedde85744..1ed34a0660aa 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -41,7 +41,7 @@ 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram', - 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace', + 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack', 'average', 'mean', 'maximum', 'minimum', @@ -472,13 +472,13 @@ def pick(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute pick') - def sort(self, *args, **kwargs): + def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ """Convenience fluent method for :py:func:`sort`. The arguments are the same as for :py:func:`sort`, with this array as data. """ - raise NotImplementedError + raise sort(self, axis=axis, kind=kind, order=order) def topk(self, *args, **kwargs): """Convenience fluent method for :py:func:`topk`. @@ -1625,6 +1625,39 @@ def argsort(a, axis=-1, kind=None, order=None): return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64') +@set_module('mxnet.symbol.numpy') +def sort(a, axis=-1, kind=None, order=None): + """ + Return a sorted copy of an array. + + Parameters + ---------- + a : _Symbol + Array to be sorted. + axis : int or None, optional + Axis along which to sort. The default is -1 (the last axis). If None, + the flattened array is used. + kind : string, optional + This argument can take any string, but it does not have any effect on the + final result. + order : str or list of str, optional + Not supported yet, will raise NotImplementedError if not None. + + Returns + ------- + sorted_array : ndarray + Array of the same type and shape as `a`. + + Notes + ----- + This operator does not support different sorting algorithms. + """ + if order is not None: + raise NotImplementedError("order is not supported yet...") + + return _npi.sort(data=a, axis=axis, is_ascend=True) + + @set_module('mxnet.symbol.numpy') def tensordot(a, b, axes=2): r""" diff --git a/src/operator/tensor/ordering_op.cc b/src/operator/tensor/ordering_op.cc index 69af70b96cc3..20147e13f4a9 100644 --- a/src/operator/tensor/ordering_op.cc +++ b/src/operator/tensor/ordering_op.cc @@ -107,6 +107,7 @@ NNVM_REGISTER_OP(_backward_topk) }); NNVM_REGISTER_OP(sort) +.add_alias("_npi_sort") .describe(R"code(Returns a sorted copy of an input array along the given axis. Examples:: diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index ad108522a58c..935f46e469ec 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -797,6 +797,17 @@ def _add_workload_argsort(): OpArgMngr.add_workload('argsort', a, axis) +def _add_workload_sort(): + OpArgMngr.add_workload('sort', np.random.uniform(0, 100), axis=None) + OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=()), axis=None) + OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(2, 3, 4)), axis=None) + OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0)), axis=None) + OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=-1) + OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(4, 3, 5)), axis=-1, kind='mergesort') + OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=None, kind='quicksort') + OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0))) + + def _add_workload_broadcast_arrays(array_pool): OpArgMngr.add_workload('broadcast_arrays', array_pool['4x1'], array_pool['1x2']) @@ -1814,6 +1825,7 @@ def _prepare_workloads(): _add_workload_around() _add_workload_round() _add_workload_argsort() + _add_workload_sort() _add_workload_append() _add_workload_bincount() _add_workload_broadcast_arrays(array_pool) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 13817cbabe93..b599d5650f74 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1417,6 +1417,62 @@ def hybrid_forward(self, F, x): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False) +@with_seed() +@use_np +def test_np_sort(): + class TestSort(HybridBlock): + def __init__(self, axis, kind): + super(TestSort, self).__init__() + self._axis = axis + self._kind = kind + + def hybrid_forward(self, F, x, *args, **kwargs): + return F.np.sort(x, self._axis, self._kind) + + dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64] + shapes = [ + (), + (1,), + (5,), + (4, 3), + (3, 5), + (4, 4), + (4, 5), + (5, 5), + (5, 6), + (6, 6), + (0, 1), + (6, 5, 6), + (2, 3, 3, 4), + (4, 2, 1, 2), + (0, 5, 3, 3), + (5, 0, 3, 3), + (3, 3, 0, 0), + ] + flags = [True, False] + # Not include 'stable' as some old numpy versions do not support it + kind_list = ['quicksort', 'mergesort', 'heapsort'] + + for dtype, shape, hybridize, kind in itertools.product(dtypes, shapes, flags, kind_list): + a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype) + axis_list = list(range(len(shape))) + axis_list.append(None) + axis_list.append(-1) + for axis in axis_list: + test = TestSort(axis, kind) + if hybridize: + test.hybridize() + if axis == -1 and len(shape)==0: + continue + ret = test(a) + expected_ret = _np.sort(a.asnumpy(), axis, kind) + assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) + + # check imperative again + ret = np.sort(a, axis, kind) + assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False) + + @with_seed() @use_np def test_np_squeeze():