Skip to content

Commit

Permalink
flatnonzero (apache#17690)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yiyan66 authored and sxjscience committed Jul 1, 2020
1 parent a7e7fa4 commit b9c6a4a
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 7 deletions.
41 changes: 40 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_', 'flatnonzero',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -4989,6 +4989,45 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-
raise NotImplementedError('Do not support column-major (Fortran-style) order at this moment')


def flatnonzero(a):
r"""
Return indices that are non-zero in the flattened version of a.
This is equivalent to np.nonzero(np.ravel(a))[0].
Parameters
----------
a : array_like
Input data.
Returns
-------
res : ndarray
Output array, containing the indices of the elements of `a.ravel()`
that are non-zero.
See Also
--------
nonzero : Return the indices of the non-zero elements of the input array.
ravel : Return a 1-D array containing the elements of the input array.
Examples
--------
>>> x = np.arange(-2, 3)
>>> x
array([-2, -1, 0, 1, 2])
>>> np.flatnonzero(x)
array([0, 1, 3, 4])
Use the indices of the non-zero elements as an index array to extract
these elements:
>>> x.ravel()[np.flatnonzero(x)]
array([-2, -1, 1, 2])
"""
return nonzero(ravel(a))[0]


def diag_indices_from(arr):
"""
This returns a tuple of indices that can be used to access the main diagonal of an array
Expand Down
2 changes: 0 additions & 2 deletions python/mxnet/numpy/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
'digitize',
'divmod',
'extract',
'flatnonzero',
'float_power',
'frexp',
'heaviside',
Expand Down Expand Up @@ -124,7 +123,6 @@
digitize = onp.digitize
divmod = onp.divmod
extract = onp.extract
flatnonzero = onp.flatnonzero
float_power = onp.float_power
frexp = onp.frexp
heaviside = onp.heaviside
Expand Down
41 changes: 40 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'histogram',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'append', 'argsort',
'sort', 'tensordot', 'eye', 'linspace', 'logspace', 'expand_dims', 'tile', 'arange',
'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'array_split', 'split', 'hsplit', 'vsplit', 'dsplit', 'flatnonzero',
'concatenate', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'insert',
'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman',
Expand Down Expand Up @@ -6837,6 +6837,45 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-
return _mx_nd_np.unravel_index(indices, shape, order=order)


def flatnonzero(a):
r"""
Return indices that are non-zero in the flattened version of a.
This is equivalent to np.nonzero(np.ravel(a))[0].
Parameters
----------
a : array_like
Input data.
Returns
-------
res : ndarray
Output array, containing the indices of the elements of `a.ravel()`
that are non-zero.
See Also
--------
nonzero : Return the indices of the non-zero elements of the input array.
ravel : Return a 1-D array containing the elements of the input array.
Examples
--------
>>> x = np.arange(-2, 3)
>>> x
array([-2, -1, 0, 1, 2])
>>> np.flatnonzero(x)
array([0, 1, 3, 4])
Use the indices of the non-zero elements as an index array to extract
these elements:
>>> x.ravel()[np.flatnonzero(x)]
array([-2, -1, 1, 2])
"""
return _mx_nd_np.flatnonzero(a)


def diag_indices_from(arr):
"""
This returns a tuple of indices that can be used to access the main diagonal of an array
Expand Down
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'transpose',
'unique',
'unravel_index',
'flatnonzero',
'diag_indices_from',
'delete',
'var',
Expand Down
28 changes: 27 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'argsort', 'sort', 'tensordot', 'eye', 'linspace',
'logspace', 'expand_dims', 'tile', 'arange', 'array_split', 'split', 'hsplit', 'vsplit', 'dsplit',
'concatenate', 'append', 'stack', 'vstack', 'row_stack', 'column_stack', 'hstack', 'dstack',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_',
'average', 'mean', 'maximum', 'minimum', 'around', 'round', 'round_', 'flatnonzero',
'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
Expand Down Expand Up @@ -4664,6 +4664,32 @@ def unravel_index(indices, shape, order='C'): # pylint: disable=redefined-outer-
raise NotImplementedError('Don not support column-major (Fortran-style) order at this moment')


def flatnonzero(a):
r"""
Return indices that are non-zero in the flattened version of a.
This is equivalent to np.nonzero(np.ravel(a))[0].
Parameters
----------
a : _Symbol
Input data.
Returns
-------
res : _Symbol
Output array, containing the indices of the elements of `a.ravel()`
that are non-zero.
See Also
--------
nonzero : Return the indices of the non-zero elements of the input array.
ravel : Return a 1-D array containing the elements of the input array.
"""
out = _npi.nonzero(ravel(a))
return out.reshape(-1,)


def diag_indices_from(arr):
"""
This returns a tuple of indices that can be used to access the main diagonal of an array
Expand Down
6 changes: 4 additions & 2 deletions tests/python/unittest/test_numpy_interoperability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2222,8 +2222,10 @@ def _add_workload_extract():
OpArgMngr.add_workload('extract', condition, arr)


def _add_workload_flatnonzero():
def _add_workload_flatnonzero(array_pool):
x = np.array([-2, -1, 0, 1, 2])
OpArgMngr.add_workload('flatnonzero', array_pool['4x1'])
OpArgMngr.add_workload('flatnonzero', array_pool['1x2'])
OpArgMngr.add_workload('flatnonzero', x)


Expand Down Expand Up @@ -2911,7 +2913,7 @@ def _prepare_workloads():
_add_workload_digitize()
_add_workload_divmod()
_add_workload_extract()
_add_workload_flatnonzero()
_add_workload_flatnonzero(array_pool)
_add_workload_float_power()
_add_workload_frexp()
_add_workload_histogram2d()
Expand Down
30 changes: 30 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -6456,6 +6456,36 @@ def hybrid_forward(self, F, x):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_flatnonzero():
class TestFlatnonzero(HybridBlock):
def __init__(self):
super(TestFlatnonzero, self).__init__()

def hybrid_forward(self, F, a):
return F.np.flatnonzero(a)

shapes = [(1,), (4, 3), (4, 5), (2, 1), (6, 5, 6), (4, 2, 1, 2),
(5, 1, 3, 3), (3, 3, 1, 0),]
types = ['int32', 'int64', 'float32', 'float64']
hybridizes = [True, False]
for hybridize, oneType, shape in itertools.product(hybridizes, types, shapes):
rtol, atol = 1e-3, 1e-5
test_flatnonzero = TestFlatnonzero()
if hybridize:
test_flatnonzero.hybridize()
x = rand_ndarray(shape, dtype=oneType).as_np_ndarray()
np_out = _np.flatnonzero(x.asnumpy())
mx_out = test_flatnonzero(x)
assert mx_out.shape == np_out.shape
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)

mx_out = np.flatnonzero(x)
np_out = _np.flatnonzero(x.asnumpy())
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol)


@with_seed()
@use_np
def test_np_round():
Expand Down

0 comments on commit b9c6a4a

Please sign in to comment.