Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy op take
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Sep 25, 2019
1 parent 883a8b1 commit e78e5f0
Show file tree
Hide file tree
Showing 7 changed files with 358 additions and 10 deletions.
87 changes: 86 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity']
'unique', 'lcm', 'tril', 'identity', 'take']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -254,6 +254,91 @@ def identity(n, dtype=None, ctx=None):
return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype)


@set_module('mxnet.ndarray.numpy')
def take(a, indices, axis=None, mode='raise', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : ndarray
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])
In this example for `a` is an ndarray, "fancy" indexing can be used.
>>> a[indices]
array([4., 3., 6.])
If `indices` is not one dimensional, the output also has these dimensions.
>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
if mode not in ('wrap', 'clip', 'raise'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
if axis:
return _npi.take(a, indices, axis, mode, out)
else:
return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
Expand Down
85 changes: 82 additions & 3 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity']
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -1104,13 +1104,13 @@ def slice_assign(self, rhs, begin, end, step):
"""
return _npi.slice_assign(self, rhs, begin=begin, end=end, step=step, out=self)

def take(self, *args, **kwargs):
def take(self, indices, axis=None, mode='raise', out=None):
"""Convenience fluent method for :py:func:`take`.
The arguments are the same as for :py:func:`take`, with
this array as data.
"""
raise NotImplementedError
take(self, indices, axis, mode=mode, out=out)

def one_hot(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`one_hot`.
Expand Down Expand Up @@ -1890,6 +1890,85 @@ def identity(n, dtype=None, ctx=None):
return _mx_nd_np.identity(n, dtype, ctx)


@set_module('mxnet.numpy')
def take(a, indices, axis=None, mode='raise', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : ndarray
The source array.
indices : ndarray
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : ndarray, optional
If provided, the result will be placed in this array. It should
be of the appropriate shape and dtype.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : ndarray
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
Examples
--------
>>> a = np.array([4, 3, 5, 7, 6, 8])
>>> indices = np.array([0, 1, 4])
>>> np.take(a, indices)
array([4., 3., 6.])
In this example for `a` is an ndarray, "fancy" indexing can be used.
>>> a[indices]
array([4., 3., 6.])
If `indices` is not one dimensional, the output also has these dimensions.
>>> np.take(a, np.array([[0, 1], [2, 3]]))
array([[4., 3.],
[5., 7.]])
"""
return _mx_nd_np.take(a, indices, axis, mode, out)


@set_module('mxnet.numpy')
def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None):
"""
Expand Down
72 changes: 69 additions & 3 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity']
'unique', 'lcm', 'tril', 'identity', 'take']


def _num_outputs(sym):
Expand Down Expand Up @@ -347,13 +347,13 @@ def slice_like(self, *args, **kwargs):
"""
raise AttributeError('_Symbol object has no attribute slice_like')

def take(self, *args, **kwargs):
def take(self, indices, axis=None, mode='raise', out=None):
"""Convenience fluent method for :py:func:`take`.
The arguments are the same as for :py:func:`take`, with
this array as data.
"""
raise NotImplementedError
return take(self, indices, axis, mode=mode, out=out)

def one_hot(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`one_hot`.
Expand Down Expand Up @@ -1026,6 +1026,72 @@ def identity(n, dtype=None, ctx=None):
return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype)


@set_module('mxnet.symbol.numpy')
def take(a, indices, axis=None, mode='raise', out=None):
r"""
Take elements from an array along an axis.
When axis is not None, this function does the same thing as "fancy"
indexing (indexing arrays using arrays); however, it can be easier to use
if you need elements along a given axis. A call such as
``np.take(arr, indices, axis=3)`` is equivalent to
``arr[:,:,:,indices,...]``.
Explained without fancy indexing, this is equivalent to the following use
of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of
indices::
Ni, Nk = a.shape[:axis], a.shape[axis+1:]
Nj = indices.shape
for ii in ndindex(Ni):
for jj in ndindex(Nj):
for kk in ndindex(Nk):
out[ii + jj + kk] = a[ii + (indices[jj],) + kk]
Parameters
----------
a : _Symbol
The source array.
indices : _Symbol
The indices of the values to extract. Also allow scalars for indices.
axis : int, optional
The axis over which to select values. By default, the flattened
input array is used.
out : _Symbol or None, optional
Dummy parameter to keep the consistency with the ndarray counterpart.
mode : {'clip', 'wrap'}, optional
Specifies how out-of-bounds indices will behave.
* 'clip' -- clip to the range (default)
* 'wrap' -- wrap around
'clip' mode means that all indices that are too large are replaced
by the index that addresses the last element along that axis. Note
that this disables indexing with negative numbers.
Returns
-------
out : _Symbol
The returned array has the same type as `a`.
Notes
-----
This function differs from the original `numpy.take
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.take.html>`_ in
the following way(s):
- Only ndarray or scalar ndarray is accepted as valid input.
"""
if mode not in ('wrap', 'clip', 'raise'):
raise NotImplementedError(
"function take does not support mode '{}'".format(mode))
if axis:
return _npi.take(a, indices, axis, mode, out)
else:
return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out)


#pylint: disable= too-many-arguments, no-member, protected-access
def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None):
""" Helper function for element-wise operation.
Expand Down
4 changes: 4 additions & 0 deletions src/operator/tensor/indexing_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,10 @@ void TakeOpForward<cpu>(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& arrshape = inputs[take_::kArr].shape_;
const mxnet::TShape& oshape = outputs[take_::kOut].shape_;

if (idxshape.Size() == 0) {
return;
}

Stream<cpu> *s = ctx.get_stream<cpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

Expand Down
4 changes: 4 additions & 0 deletions src/operator/tensor/indexing_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,10 @@ void TakeOpForward<gpu>(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& arrshape = inputs[take_::kArr].shape_;
const mxnet::TShape& oshape = outputs[take_::kOut].shape_;

if (idxshape.Size() == 0) {
return;
}

Stream<gpu> *s = ctx.get_stream<gpu>();
const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0);

Expand Down
10 changes: 7 additions & 3 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -670,9 +670,9 @@ struct TakeParam: public dmlc::Parameter<TakeParam> {
.set_default(take_::kClip)
.describe("Specify how out-of-bound indices bahave. Default is \"clip\"."
" \"clip\" means clip to the range. So, if all indices mentioned are too large,"
" they are replaced by the index that addresses the last element along an axis. "
" \"wrap\" means to wrap around. "
" \"raise\" means to raise an error, not supported yet.");
" they are replaced by the index that addresses the last element along an axis."
" \"wrap\" means to wrap around."
" \"raise\" means to raise an error when index out of range.");
}
};

Expand Down Expand Up @@ -1030,6 +1030,10 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs,
const mxnet::TShape& arrshape = outputs[0].shape_;
const mxnet::TShape& oshape = inputs[0].shape_;

if (idxshape.Size() == 0) {
return;
}

if (req[take_::kIdx] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
s, idxshape.Size(), outputs[take_::kIdx].dptr<IType>());
Expand Down
Loading

0 comments on commit e78e5f0

Please sign in to comment.