Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy take operator implementation & bug fix in ndarray.take #15699

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 83 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye',
'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
'argsort', 'hstack', 'tensordot']
'argsort', 'hstack', 'tensordot', 'take']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -590,6 +590,88 @@ def concatenate(seq, axis=0, out=None):
return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.ndarray.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.

When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.

Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.

* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around

'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.

Returns
-------
out : ndarray
The returned array has the same type as `a`.

Notes
-----

This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):

- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.

Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])

In this example for `a` is an ndarray, "fancy" indexing can be used.

>>> a[indices]
array([4., 3., 6.])

If `indices` is not one dimensional, the output also has these dimensions.

>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
if mode not in ('wrap', 'clip'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
return _npi.take(a, indices, axis, mode, out)

hgt312 marked this conversation as resolved.
Show resolved Hide resolved
@set_module('mxnet.ndarray.numpy')
def hstack(arrays):
"""
Expand Down
83 changes: 81 additions & 2 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'clip', 'split', 'swapaxes', 'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos',
'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt', 'abs', 'exp', 'arctan', 'sign', 'log',
'degrees', 'log2', 'rint', 'radians', 'mean', 'reciprocal', 'square', 'arcsin',
'argsort', 'hstack', 'tensordot']
'argsort', 'hstack', 'tensordot', 'take']


@set_module('mxnet.numpy')
Expand Down Expand Up @@ -1879,6 +1879,85 @@ def concatenate(seq, axis=0, out=None):
return _mx_nd_np.concatenate(seq, axis=axis, out=out)


@set_module('mxnet.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.

When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.

Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.

* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around

'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.

Returns
-------
out : ndarray
The returned array has the same type as `a`.

Notes
-----

This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):

- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.

Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])

In this example for `a` is an ndarray, "fancy" indexing can be used.

>>> a[indices]
array([4., 3., 6.])

If `indices` is not one dimensional, the output also has these dimensions.

>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
return _mx_nd_np.take(a, indices, axis, mode, out)

hgt312 marked this conversation as resolved.
Show resolved Hide resolved
@set_module('mxnet.numpy')
def add(x1, x2, out=None):
"""Add arguments element-wise.
Expand Down Expand Up @@ -2911,7 +2990,7 @@ def radians(x, out=None, **kwargs):
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.radians.html>`_ in
the following way(s):

- only ndarray or scalar is accpted as valid input, tuple of ndarray is not supported
- only ndarray or scalar is accepted as valid input, tuple of ndarray is not supported
- broadcasting to `out` of different shape is currently not supported
- when input is plain python numerics, the result will not be stored in the `out` param

Expand Down
66 changes: 65 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims', 'tile', 'linspace', 'eye', 'sin', 'cos', 'sinh', 'cosh', 'log10', 'sqrt',
'abs', 'exp', 'arctan', 'sign', 'log', 'degrees', 'log2', 'rint', 'radians', 'mean',
'reciprocal', 'square', 'arcsin', 'argsort', 'hstack', 'tensordot']
'reciprocal', 'square', 'arcsin', 'argsort', 'hstack', 'tensordot', 'take']


def _num_outputs(sym):
Expand Down Expand Up @@ -1255,6 +1255,70 @@ def concatenate(seq, axis=0, out=None):
return _npi.concatenate(*seq, dim=axis, out=out)


@set_module('mxnet.symbol.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.

When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.

Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::

Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]

Parameters
----------
a : _Symbol
The source array.
indices : _Symbol
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : _Symbol or None, optional
Dummy parameter to keep the consistency with the ndarray counterpart.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.

* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around

'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.

Returns
-------
out : _Symbol
The returned array has the same type as `a`.

Notes
-----

This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):

- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.
"""
if mode not in ('wrap', 'clip'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
return _npi.take(a, indices, axis, mode, out)


@set_module('mxnet.symbol.numpy')
def arange(start, stop=None, step=1, dtype=None, ctx=None):
"""Return evenly spaced values within a given interval.
Expand Down
Loading