Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy op take
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Aug 12, 2019
1 parent 57927a9 commit 375d821
Show file tree
Hide file tree
Showing 12 changed files with 964 additions and 64 deletions.
85 changes: 84 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from ...context import current_context
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'take']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -146,6 +146,89 @@ def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, ou
#pylint: enable= too-many-arguments, no-member, protected-access


@set_module('mxnet.ndarray.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : ndarray
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.
Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])
In this example for `a` is an ndarray, "fancy" indexing can be used.
>>> a[indices]
array([4., 3., 6.])
If `indices` is not one dimensional, the output also has these dimensions.
>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
if mode not in ('wrap', 'clip'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
return _npi.take(a, indices, axis, mode, out)


@set_module('mxnet.ndarray.numpy')
def add(x1, x2, out=None):
"""Add arguments element-wise.
Expand Down
82 changes: 81 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power']
'mod', 'power', 'take']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1405,6 +1405,86 @@ def ones(shape, dtype=_np.float32, order='C', ctx=None):
return _mx_nd_np.ones(shape, dtype, order, ctx)


@set_module('mxnet.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : ndarray
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.
Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])
In this example for `a` is an ndarray, "fancy" indexing can be used.
>>> a[indices]
array([4., 3., 6.])
If `indices` is not one dimensional, the output also has these dimensions.
>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
return _mx_nd_np.take(a, indices, axis, mode, out)


@set_module('mxnet.numpy')
def add(x1, x2, out=None):
"""Add arguments element-wise.
Expand Down
66 changes: 65 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from .._internal import _set_np_symbol_class
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power']
__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'take']


def _num_outputs(sym):
Expand Down Expand Up @@ -1010,4 +1010,68 @@ def power(x1, x2, out=None):
return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out)


@set_module('mxnet.symbol.numpy')
def take(a, indices, axis=None, mode='clip', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : _Symbol
The source array.
indices : _Symbol
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : _Symbol or None, optional
Dummy parameter to keep the consistency with the ndarray counterpart.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : _Symbol
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
- 'raise' mode is not supported.
"""
if mode not in ('wrap', 'clip'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
return _npi.take(a, indices, axis, mode, out)


_set_np_symbol_class(_Symbol)
Loading

0 comments on commit 375d821

Please sign in to comment.