diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 2b3b53b2af52..9521dc5aadb5 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -35,7 +35,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity'] + 'unique', 'lcm', 'tril', 'identity', 'take'] @set_module('mxnet.ndarray.numpy') @@ -254,6 +254,91 @@ def identity(n, dtype=None, ctx=None): return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) +@set_module('mxnet.ndarray.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : ndarray + The source array. + indices : ndarray + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : ndarray, optional + If provided, the result will be placed in this array. It should + be of the appropriate shape and dtype. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : ndarray + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + + Examples + -------- + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> np.take(a, indices) + array([4., 3., 6.]) + + In this example for `a` is an ndarray, "fancy" indexing can be used. + + >>> a[indices] + array([4., 3., 6.]) + + If `indices` is not one dimensional, the output also has these dimensions. + + >>> np.take(a, np.array([[0, 1], [2, 3]])) + array([[4., 3.], + [5., 7.]]) + """ + if mode not in ('wrap', 'clip', 'raise'): + raise NotImplementedError( + "function take does not support mode '{}'".format(mode)) + if axis: + return _npi.take(a, indices, axis, mode, out) + else: + return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index a0c25ab5c525..31a829b6a98a 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -54,7 +54,7 @@ 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity'] + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -1104,13 +1104,13 @@ def slice_assign(self, rhs, begin, end, step): """ return _npi.slice_assign(self, rhs, begin=begin, end=end, step=step, out=self) - def take(self, *args, **kwargs): + def take(self, indices, axis=None, mode='raise', out=None): """Convenience fluent method for :py:func:`take`. The arguments are the same as for :py:func:`take`, with this array as data. """ - raise NotImplementedError + take(self, indices, axis, mode=mode, out=out) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -1890,6 +1890,85 @@ def identity(n, dtype=None, ctx=None): return _mx_nd_np.identity(n, dtype, ctx) +@set_module('mxnet.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : ndarray + The source array. + indices : ndarray + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : ndarray, optional + If provided, the result will be placed in this array. It should + be of the appropriate shape and dtype. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : ndarray + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + + Examples + -------- + >>> a = np.array([4, 3, 5, 7, 6, 8]) + >>> indices = np.array([0, 1, 4]) + >>> np.take(a, indices) + array([4., 3., 6.]) + + In this example for `a` is an ndarray, "fancy" indexing can be used. + + >>> a[indices] + array([4., 3., 6.]) + + If `indices` is not one dimensional, the output also has these dimensions. + + >>> np.take(a, np.array([[0, 1], [2, 3]])) + array([[4., 3.], + [5., 7.]]) + """ + return _mx_nd_np.take(a, indices, axis, mode, out) + + @set_module('mxnet.numpy') def unique(ar, return_index=False, return_inverse=False, return_counts=False, axis=None): """ diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index e3820601e2ec..a8bacbb7feb8 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -37,7 +37,7 @@ 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity'] + 'unique', 'lcm', 'tril', 'identity', 'take'] def _num_outputs(sym): @@ -347,13 +347,13 @@ def slice_like(self, *args, **kwargs): """ raise AttributeError('_Symbol object has no attribute slice_like') - def take(self, *args, **kwargs): + def take(self, indices, axis=None, mode='raise', out=None): """Convenience fluent method for :py:func:`take`. The arguments are the same as for :py:func:`take`, with this array as data. """ - raise NotImplementedError + return take(self, indices, axis, mode=mode, out=out) def one_hot(self, *args, **kwargs): """Convenience fluent method for :py:func:`one_hot`. @@ -1026,6 +1026,72 @@ def identity(n, dtype=None, ctx=None): return _npi.identity(shape=(n, n), ctx=ctx, dtype=dtype) +@set_module('mxnet.symbol.numpy') +def take(a, indices, axis=None, mode='raise', out=None): + r""" + Take elements from an array along an axis. + + When axis is not None, this function does the same thing as "fancy" + indexing (indexing arrays using arrays); however, it can be easier to use + if you need elements along a given axis. A call such as + ``np.take(arr, indices, axis=3)`` is equivalent to + ``arr[:,:,:,indices,...]``. + + Explained without fancy indexing, this is equivalent to the following use + of `ndindex`, which sets each of ``ii``, ``jj``, and ``kk`` to a tuple of + indices:: + + Ni, Nk = a.shape[:axis], a.shape[axis+1:] + Nj = indices.shape + for ii in ndindex(Ni): + for jj in ndindex(Nj): + for kk in ndindex(Nk): + out[ii + jj + kk] = a[ii + (indices[jj],) + kk] + + Parameters + ---------- + a : _Symbol + The source array. + indices : _Symbol + The indices of the values to extract. Also allow scalars for indices. + axis : int, optional + The axis over which to select values. By default, the flattened + input array is used. + out : _Symbol or None, optional + Dummy parameter to keep the consistency with the ndarray counterpart. + mode : {'clip', 'wrap'}, optional + Specifies how out-of-bounds indices will behave. + + * 'clip' -- clip to the range (default) + * 'wrap' -- wrap around + + 'clip' mode means that all indices that are too large are replaced + by the index that addresses the last element along that axis. Note + that this disables indexing with negative numbers. + + Returns + ------- + out : _Symbol + The returned array has the same type as `a`. + + Notes + ----- + + This function differs from the original `numpy.take + `_ in + the following way(s): + + - Only ndarray or scalar ndarray is accepted as valid input. + """ + if mode not in ('wrap', 'clip', 'raise'): + raise NotImplementedError( + "function take does not support mode '{}'".format(mode)) + if axis: + return _npi.take(a, indices, axis, mode, out) + else: + return _npi.take(_npi.reshape(a, -1), indices, 0, mode, out) + + #pylint: disable= too-many-arguments, no-member, protected-access def _ufunc_helper(lhs, rhs, fn_array, fn_scalar, lfn_scalar, rfn_scalar=None, out=None): """ Helper function for element-wise operation. diff --git a/src/operator/tensor/indexing_op.cc b/src/operator/tensor/indexing_op.cc index 463e9f98820e..9961218b5482 100644 --- a/src/operator/tensor/indexing_op.cc +++ b/src/operator/tensor/indexing_op.cc @@ -288,6 +288,10 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = inputs[take_::kArr].shape_; const mxnet::TShape& oshape = outputs[take_::kOut].shape_; + if (idxshape.Size() == 0) { + return; + } + Stream *s = ctx.get_stream(); const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu index 9a46d894ee22..77d85d8e1e10 100644 --- a/src/operator/tensor/indexing_op.cu +++ b/src/operator/tensor/indexing_op.cu @@ -479,6 +479,10 @@ void TakeOpForward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = inputs[take_::kArr].shape_; const mxnet::TShape& oshape = outputs[take_::kOut].shape_; + if (idxshape.Size() == 0) { + return; + } + Stream *s = ctx.get_stream(); const int actual_axis = param.axis + ((param.axis < 0) ? arrshape.ndim() : 0); diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 161acae0ebf2..16520ddbb242 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -670,9 +670,9 @@ struct TakeParam: public dmlc::Parameter { .set_default(take_::kClip) .describe("Specify how out-of-bound indices bahave. Default is \"clip\"." " \"clip\" means clip to the range. So, if all indices mentioned are too large," - " they are replaced by the index that addresses the last element along an axis. " - " \"wrap\" means to wrap around. " - " \"raise\" means to raise an error, not supported yet."); + " they are replaced by the index that addresses the last element along an axis." + " \"wrap\" means to wrap around." + " \"raise\" means to raise an error when index out of range."); } }; @@ -1030,6 +1030,10 @@ void TakeOpBackward(const nnvm::NodeAttrs& attrs, const mxnet::TShape& arrshape = outputs[0].shape_; const mxnet::TShape& oshape = inputs[0].shape_; + if (idxshape.Size() == 0) { + return; + } + if (req[take_::kIdx] != kNullOp) { mxnet_op::Kernel::Launch( s, idxshape.Size(), outputs[take_::kIdx].dptr()); diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 274e0f16095c..7e3d9655f771 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2704,6 +2704,112 @@ def hybrid_forward(self, F, x1, x2): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_take(): + configs = [ + ((4, 4), (4, 0), None), + ((4, 4), (4, 0), 0), + ((4, 4), (4, 0), 1), + ((), (4, 0), None), + ((), (5, ), None), + ((), (4, 5), None), + ((), (), None), + ((3, 4), (), None), + ((3, 4), (), 0), + ((3, 4), (), 1), + ((3, 4, 5), (), 2), + ((3, 4, 5), (), -3), + ] + + class TestTake(HybridBlock): + def __init__(self, axis, mode): + super(TestTake, self).__init__() + self._axis = axis + self._mode = mode + + def hybrid_forward(self, F, a, indices): + return F.np.take(a, indices, axis=self._axis, mode=self._mode) + + def grad_helper(grad_in, axis, idx, mode): + k = grad_in.shape[axis] + if mode == 'clip': + idx = 0 if idx < 0 else idx + idx = k - 1 if idx >= k else idx + else: + idx = idx % k + if axis == None: + grad_in[idx] += 1.0 + elif axis == 0: + if axis == len(grad_in.shape) - 1: + grad_in[idx] += 1.0 + else: + grad_in[idx, :] += 1.0 + elif axis == 1: + if axis == len(grad_in.shape) - 1: + grad_in[:, idx] += 1.0 + else: + grad_in[:, idx, :] += 1.0 + elif axis == 2: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, idx] += 1.0 + else: + grad_in[:, :, idx, :] += 1.0 + elif axis == 3: + if axis == len(grad_in.shape) - 1: + grad_in[:, :, :, idx] += 1.0 + else: + grad_in[:, :, :, idx, :] += 1.0 + elif axis == 4: + grad_in[:, :, :, :, idx] += 1.0 + else: + raise ValueError("axis %d is not supported..." % axis) + + def check_output_n_grad(data_shape, idx_shape, axis, mode): + data_real = _np.random.normal(size=data_shape).astype('float32') + idx_real = _np.random.randint(low=-100, high=100, size=idx_shape) + same(np.take(np.array(data_real), np.array(idx_real), axis=axis, mode=mode).asnumpy(), + _np.take(data_real, idx_real, axis=axis, mode=mode)) + + grad_in = _np.zeros(data_shape, dtype='float32') + + test_take = TestTake(axis=axis, mode=mode) + if hybridize: + test_take.hybridize() + x = np.array(data_real) + x.attach_grad() + with mx.autograd.record(): + mx_out = test_take(x, np.array(idx_real)) + same(mx_out.asnumpy(), _np.take(data_real, idx_real, axis=axis, mode=mode)) + + if axis and axis < 0: + axis += len(data_shape) + try: + for i in _np.nditer(idx_real): + grad_helper(grad_in, axis, i, mode) + except: + pass + + mx_out.backward() + same(x.grad.asnumpy(), grad_in) + + for hybridize in [True, False]: + for mode in ['clip', 'wrap']: + for data_ndim in range(1, 5): + for idx_ndim in range(1, 4): + for axis in range(-data_ndim, data_ndim): + data_shape = () + for _ in range(data_ndim): + data_shape += (_np.random.randint(low=1, high=5), ) + idx_shape = () + for _ in range(idx_ndim): + idx_shape += (_np.random.randint(low=1, high=5), ) + check_output_n_grad(data_shape, idx_shape, axis, mode) + + for config in configs: + check_output_n_grad(config[0], config[1], config[2], mode) + + if __name__ == '__main__': import nose nose.runmodule()