Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] Add sort op
Browse files Browse the repository at this point in the history
* Fix sanity
  • Loading branch information
hanke580 committed Jan 21, 2020
1 parent 425319c commit fcdd9f1
Show file tree
Hide file tree
Showing 7 changed files with 190 additions and 7 deletions.
46 changes: 45 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort',
'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
Expand Down Expand Up @@ -1223,6 +1224,49 @@ def argsort(a, axis=-1, kind=None, order=None):
return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')


@set_module('mxnet.ndarray.numpy')
def sort(a, axis=-1, kind=None, order=None):
"""
Return a sorted copy of an array.
Parameters
----------
a : ndarray
Array to be sorted.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
sorted_array : ndarray
Array of the same type and shape as `a`.
Notes
-----
This operator does not support different sorting algorithms.
--------
Examples
--------
>>> a = np.array([[1,4],[3,1]])
>>> np.sort(a) # sort along the last axis
array([[1, 4],
[1, 3]])
>>> np.sort(a, axis=None) # sort the flattened array
array([1, 1, 3, 4])
>>> np.sort(a, axis=0) # sort along the first axis
array([[1, 1],
[3, 4]])
"""
if order is not None:
raise NotImplementedError("order not supported here")
return _npi.sort(data=a, axis=axis, is_ascend=True)


@set_module('mxnet.ndarray.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
47 changes: 44 additions & 3 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange', 'array_split',
'split', 'vsplit', 'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var',
'indices', 'copysign', 'ravel', 'unravel_index', 'hanning', 'hamming', 'blackman', 'flip', 'flipud',
Expand Down Expand Up @@ -1522,13 +1522,13 @@ def pick(self, *args, **kwargs):
"""
raise AttributeError('mxnet.numpy.ndarray object has no attribute pick')

def sort(self, *args, **kwargs):
def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sort`.
The arguments are the same as for :py:func:`sort`, with
this array as data.
"""
raise NotImplementedError
raise sort(self, axis=axis, kind=kind, order=order)

def topk(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`topk`.
Expand Down Expand Up @@ -4635,6 +4635,47 @@ def argsort(a, axis=-1, kind=None, order=None):
return _mx_nd_np.argsort(a, axis=axis, kind=kind, order=order)


@set_module('mxnet.numpy')
def sort(a, axis=-1, kind=None, order=None):
"""
Return a sorted copy of an array.
Parameters
----------
a : ndarray
Array to be sorted.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
sorted_array : ndarray
Array of the same type and shape as `a`.
Notes
-----
This operator does not support different sorting algorithms.
--------
Examples
--------
>>> a = np.array([[1,4],[3,1]])
>>> np.sort(a) # sort along the last axis
array([[1, 4],
[1, 3]])
>>> np.sort(a, axis=None) # sort the flattened array
array([1, 1, 3, 4])
>>> np.sort(a, axis=0) # sort along the first axis
array([[1, 1],
[3, 4]])
"""
return _mx_nd_np.sort(a, axis=axis, kind=kind, order=order)


@set_module('mxnet.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'around',
'round',
'argsort',
'sort',
'append',
'broadcast_arrays',
'broadcast_to',
Expand Down
34 changes: 31 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'histogram',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'tensordot', 'eye', 'linspace',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'vsplit', 'concatenate', 'append',
'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum',
Expand Down Expand Up @@ -471,13 +471,13 @@ def pick(self, *args, **kwargs):
"""
raise AttributeError('_Symbol object has no attribute pick')

def sort(self, *args, **kwargs):
def sort(self, axis=-1, kind=None, order=None): # pylint: disable=arguments-differ
"""Convenience fluent method for :py:func:`sort`.
The arguments are the same as for :py:func:`sort`, with
this array as data.
"""
raise NotImplementedError
raise sort(self, axis=axis, kind=kind, order=order)

def topk(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`topk`.
Expand Down Expand Up @@ -1624,6 +1624,34 @@ def argsort(a, axis=-1, kind=None, order=None):
return _npi.argsort(data=a, axis=axis, is_ascend=True, dtype='int64')


@set_module('mxnet.symbol.numpy')
def sort(a, axis=-1, kind=None, order=None):
"""
Return a sorted copy of an array.
Parameters
----------
a : _Symbol
Array to be sorted.
axis : int or None, optional
Axis along which to sort. The default is -1 (the last axis). If None,
the flattened array is used.
kind : string, optional
This argument can take any string, but it does not have any effect on the
final result.
order : str or list of str, optional
Not supported yet, will raise NotImplementedError if not None.
Returns
-------
sorted_array : ndarray
Array of the same type and shape as `a`.
Notes
-----
This operator does not support different sorting algorithms.
"""
return _npi.sort(data=a, axis=axis, is_ascend=True)


@set_module('mxnet.symbol.numpy')
def tensordot(a, b, axes=2):
r"""
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/ordering_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ NNVM_REGISTER_OP(_backward_topk)
});

NNVM_REGISTER_OP(sort)
.add_alias("_npi_sort")
.describe(R"code(Returns a sorted copy of an input array along the given axis.
Examples::
Expand Down
12 changes: 12 additions & 0 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,17 @@ def _add_workload_argsort():
OpArgMngr.add_workload('argsort', a, axis)


def _add_workload_sort():
OpArgMngr.add_workload('sort', np.random.uniform(0, 100), axis=None)
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=()), axis=None)
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(2, 3, 4)), axis=None)
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0)), axis=None)
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=-1)
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(4, 3, 5)), axis=-1, kind='mergesort')
OpArgMngr.add_workload('sort', np.random.randint(0, 100, size=(2, 3, 4)), axis=None, kind='quicksort')
OpArgMngr.add_workload('sort', np.random.uniform(0, 100, size=(4, 3, 0)))


def _add_workload_broadcast_arrays(array_pool):
OpArgMngr.add_workload('broadcast_arrays', array_pool['4x1'], array_pool['1x2'])

Expand Down Expand Up @@ -1691,6 +1702,7 @@ def _prepare_workloads():
_add_workload_around()
_add_workload_round()
_add_workload_argsort()
_add_workload_sort()
_add_workload_append()
_add_workload_bincount()
_add_workload_broadcast_arrays(array_pool)
Expand Down
56 changes: 56 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,62 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-5, atol=1e-6, use_broadcast=False)


@with_seed()
@use_np
def test_np_sort():
class TestSort(HybridBlock):
def __init__(self, axis, kind):
super(TestSort, self).__init__()
self._axis = axis
self._kind = kind

def hybrid_forward(self, F, x, *args, **kwargs):
return F.np.sort(x, self._axis, self._kind)

dtypes = [np.int8, np.uint8, np.int32, np.int64, np.float32, np.float64]
shapes = [
(),
(1,),
(5,),
(4, 3),
(3, 5),
(4, 4),
(4, 5),
(5, 5),
(5, 6),
(6, 6),
(0, 1),
(6, 5, 6),
(2, 3, 3, 4),
(4, 2, 1, 2),
(0, 5, 3, 3),
(5, 0, 3, 3),
(3, 3, 0, 0),
]
flags = [True, False]
# Not include 'stable' as some old numpy versions do not support it
kind_list = ['quicksort', 'mergesort', 'heapsort']

for dtype, shape, hybridize, kind in itertools.product(dtypes, shapes, flags, kind_list):
a = np.random.uniform(low=0, high=100, size=shape, dtype='float64').astype(dtype)
axis_list = list(range(len(shape)))
axis_list.append(None)
axis_list.append(-1)
for axis in axis_list:
test = TestSort(axis, kind)
if hybridize:
test.hybridize()
if axis == -1 and len(shape)==0:
continue
ret = test(a)
expected_ret = _np.sort(a.asnumpy(), axis, kind)
assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False)

# check imperative again
ret = np.sort(a, axis, kind)
assert_almost_equal(ret.asnumpy(), expected_ret, atol=1e-5, rtol=1e-5, use_broadcast=False)


@with_seed()
@use_np
def test_np_squeeze():
Expand Down

0 comments on commit fcdd9f1

Please sign in to comment.