diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 964d533b2387..671345c9a546 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -33,7 +33,7 @@ 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices'] @set_module('mxnet.ndarray.numpy') @@ -1888,7 +1888,7 @@ def tile(A, reps): """ return _unary_func_helper(A, _npi.tile, _np.tile, reps=reps) - +# pylint: disable=redefined-outer-name @set_module('mxnet.ndarray.numpy') def split(ary, indices_or_sections, axis=0): """Split an array into multiple sub-arrays. @@ -1936,6 +1936,7 @@ def split(ary, indices_or_sections, axis=0): if not isinstance(ret, list): return [ret] return ret +# pylint: enable=redefined-outer-name @set_module('mxnet.ndarray.numpy') @@ -2363,3 +2364,71 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: 0.2025 """ return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + + +# pylint: disable=redefined-outer-name +@set_module('mxnet.ndarray.numpy') +def indices(dimensions, dtype=_np.int32, ctx=None): + """Return an array representing the indices of a grid. + + Compute an array where the subarrays contain index values 0,1,... + varying only along the corresponding axis. + + Parameters + ---------- + dimensions : sequence of ints + The shape of the grid. + dtype : data-type, optional + The desired data-type for the array. Default is `float32`. + ctx : device context, optional + Device context on which the memory is allocated. Default is + `mxnet.context.current_context()`. + + Returns + ------- + grid : ndarray + The array of grid indices, + ``grid.shape = (len(dimensions),) + tuple(dimensions)``. + + Notes + ----- + The output shape is obtained by prepending the number of dimensions + in front of the tuple of dimensions, i.e. if `dimensions` is a tuple + ``(r0, ..., rN-1)`` of length ``N``, the output shape is + ``(N,r0,...,rN-1)``. + + The subarrays ``grid[k]`` contains the N-D array of indices along the + ``k-th`` axis. Explicitly:: + + grid[k,i0,i1,...,iN-1] = ik + + Examples + -------- + >>> grid = np.indices((2, 3)) + >>> grid.shape + (2, 2, 3) + >>> grid[0] # row indices + array([[0, 0, 0], + [1, 1, 1]]) + >>> grid[1] # column indices + array([[0, 0, 0], + [1, 1, 1]], dtype=int32) + + The indices can be used as an index into an array. + + >>> x = np.arange(20).reshape(5, 4) + >>> row, col = np.indices((2, 3)) + >>> x[row, col] + array([[0., 1., 2.], + [4., 5., 6.]]) + + Note that it would be more straightforward in the above example to + extract the required elements directly with ``x[:2, :3]``. + """ + if isinstance(dimensions, (tuple, list)): + if ctx is None: + ctx = current_context() + return _npi.indices(dimensions=dimensions, dtype=dtype, ctx=ctx) + else: + raise ValueError("The dimensions must be sequence of ints") +# pylint: enable=redefined-outer-name diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 7d6e81a2d4a5..5025aa5cdde8 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -52,7 +52,7 @@ 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -3808,3 +3808,64 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None): 0.2025 """ return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) + + +@set_module('mxnet.numpy') +def indices(dimensions, dtype=_np.int32, ctx=None): + """Return an array representing the indices of a grid. + + Compute an array where the subarrays contain index values 0,1,... + varying only along the corresponding axis. + + Parameters + ---------- + dimensions : sequence of ints + The shape of the grid. + dtype : data-type, optional + The desired data-type for the array. Default is `float32`. + ctx : device context, optional + Device context on which the memory is allocated. Default is + `mxnet.context.current_context()`. + + Returns + ------- + grid : ndarray + The array of grid indices, + ``grid.shape = (len(dimensions),) + tuple(dimensions)``. + + Notes + ----- + The output shape is obtained by prepending the number of dimensions + in front of the tuple of dimensions, i.e. if `dimensions` is a tuple + ``(r0, ..., rN-1)`` of length ``N``, the output shape is + ``(N,r0,...,rN-1)``. + + The subarrays ``grid[k]`` contains the N-D array of indices along the + ``k-th`` axis. Explicitly:: + + grid[k,i0,i1,...,iN-1] = ik + + Examples + -------- + >>> grid = np.indices((2, 3)) + >>> grid.shape + (2, 2, 3) + >>> grid[0] # row indices + array([[0, 0, 0], + [1, 1, 1]]) + >>> grid[1] # column indices + array([[0, 0, 0], + [1, 1, 1]], dtype=int32) + + The indices can be used as an index into an array. + + >>> x = np.arange(20).reshape(5, 4) + >>> row, col = np.indices((2, 3)) + >>> x[row, col] + array([[0., 1., 2.], + [4., 5., 6.]]) + + Note that it would be more straightforward in the above example to + extract the required elements directly with ``x[:2, :3]``. + """ + return _mx_nd_np.indices(dimensions=dimensions, dtype=dtype, ctx=ctx) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index ec4f6a4dd741..0841c0e4d2cc 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -35,7 +35,7 @@ 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean', - 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var'] + 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices'] def _num_outputs(sym): @@ -2305,6 +2305,7 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None): return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) +# pylint: disable=redefined-outer-name @set_module('mxnet.symbol.numpy') def split(ary, indices_or_sections, axis=0): """Split an array into multiple sub-arrays. @@ -2345,6 +2346,7 @@ def split(ary, indices_or_sections, axis=0): raise ValueError('indices_or_sections must either int or tuple of ints') ret = _npi.split(ary, indices, axis, False, sections) return ret +# pylint: enable=redefined-outer-name @set_module('mxnet.symbol.numpy') @@ -2678,4 +2680,72 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False): # pylint: return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out) +# pylint: disable=redefined-outer-name +@set_module('mxnet.symbol.numpy') +def indices(dimensions, dtype=_np.int32, ctx=None): + """Return an array representing the indices of a grid. + + Compute an array where the subarrays contain index values 0,1,... + varying only along the corresponding axis. + + Parameters + ---------- + dimensions : sequence of ints + The shape of the grid. + dtype : data-type, optional + The desired data-type for the array. Default is `float32`. + ctx : device context, optional + Device context on which the memory is allocated. Default is + `mxnet.context.current_context()`. + + Returns + ------- + grid : _Symbol + The array of grid indices, + ``grid.shape = (len(dimensions),) + tuple(dimensions)``. + + Notes + ----- + The output shape is obtained by prepending the number of dimensions + in front of the tuple of dimensions, i.e. if `dimensions` is a tuple + ``(r0, ..., rN-1)`` of length ``N``, the output shape is + ``(N,r0,...,rN-1)``. + + The subarrays ``grid[k]`` contains the N-D array of indices along the + ``k-th`` axis. Explicitly:: + + grid[k,i0,i1,...,iN-1] = ik + + Examples + -------- + >>> grid = np.indices((2, 3)) + >>> grid.shape + (2, 2, 3) + >>> grid[0] # row indices + array([[0, 0, 0], + [1, 1, 1]]) + >>> grid[1] # column indices + array([[0, 0, 0], + [1, 1, 1]], dtype=int32) + + The indices can be used as an index into an array. + + >>> x = np.arange(20).reshape(5, 4) + >>> row, col = np.indices((2, 3)) + >>> x[row, col] + array([[0., 1., 2.], + [4., 5., 6.]]) + + Note that it would be more straightforward in the above example to + extract the required elements directly with ``x[:2, :3]``. + """ + if isinstance(dimensions, (tuple, list)): + if ctx is None: + ctx = current_context() + return _npi.indices(dimensions=dimensions, dtype=dtype, ctx=ctx) + else: + raise ValueError("The dimensions must be sequence of ints") +# pylint: enable=redefined-outer-name + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_init_op.cc b/src/operator/numpy/np_init_op.cc index b85a92f97683..4f031bdaa050 100644 --- a/src/operator/numpy/np_init_op.cc +++ b/src/operator/numpy/np_init_op.cc @@ -24,10 +24,33 @@ */ #include "../tensor/init_op.h" #include "../tensor/elemwise_unary_op.h" +#include "./np_init_op.h" namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(IndicesOpParam); + +inline bool NumpyIndicesShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_shapes, + mxnet::ShapeVector* out_shapes) { + const IndicesOpParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(in_shapes->size(), 0U); + CHECK_EQ(out_shapes->size(), 1U); + CHECK_GE(param.dimensions.ndim(), 0) + << "_npi_indices dimensions the number of dim must not be less than 0"; + mxnet::TShape param_dim = param.dimensions; + if (!shape_is_known(param_dim)) return false; + const int indim = param.dimensions.ndim(); + mxnet::TShape ret(indim + 1, -1); + ret[0] = indim; + for (int i = 1; i < indim + 1; ++i) { + ret[i] = param.dimensions[i-1]; + } + SHAPE_ASSIGN_CHECK(*out_shapes, 0, ret); + return shape_is_known(out_shapes->at(0)); +} + NNVM_REGISTER_OP(_npi_zeros) .set_num_inputs(0) .set_num_outputs(1) @@ -110,5 +133,15 @@ NNVM_REGISTER_OP(_npi_arange) .set_attr("FCompute", RangeCompute) .add_arguments(RangeParam::__FIELDS__()); +NNVM_REGISTER_OP(_npi_indices) +.describe("Return an array representing the indices of a grid.") +.set_num_inputs(0) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", NumpyIndicesShape) +.set_attr("FInferType", InitType) +.set_attr("FCompute", IndicesCompute) +.add_arguments(IndicesOpParam::__FIELDS__()); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.cu b/src/operator/numpy/np_init_op.cu index fe631f388a19..49f1051735d8 100644 --- a/src/operator/numpy/np_init_op.cu +++ b/src/operator/numpy/np_init_op.cu @@ -24,6 +24,7 @@ */ #include "../tensor/init_op.h" +#include "./np_init_op.h" namespace mxnet { namespace op { @@ -43,5 +44,8 @@ NNVM_REGISTER_OP(_np_ones_like) NNVM_REGISTER_OP(_npi_arange) .set_attr("FCompute", RangeCompute); +NNVM_REGISTER_OP(_npi_indices) +.set_attr("FCompute", IndicesCompute); + } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/np_init_op.h b/src/operator/numpy/np_init_op.h new file mode 100644 index 000000000000..5c41820b57f8 --- /dev/null +++ b/src/operator/numpy/np_init_op.h @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_init_op.h + * \brief CPU Implementation of numpy init op + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ +#define MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ + +#include +#include +#include "../tensor/init_op.h" +#include "../tensor/elemwise_unary_op.h" + + +namespace mxnet { +namespace op { + +struct IndicesOpParam : public dmlc::Parameter { + mxnet::TShape dimensions; + int dtype; + std::string ctx; + DMLC_DECLARE_PARAMETER(IndicesOpParam) { + DMLC_DECLARE_FIELD(dimensions) + .describe("The shape of the grid."); + DMLC_DECLARE_FIELD(dtype).set_default(mshadow::kInt32) + MXNET_ADD_ALL_TYPES + .describe("Target data type."); + DMLC_DECLARE_FIELD(ctx) + .set_default("") + .describe("Context of output, in format [cpu|gpu|cpu_pinned](n)." + "Only used for imperative calls."); + } +}; + +template +struct indices_fwd { + template + MSHADOW_XINLINE static void Map(index_t i, DType* out, + const nnvm::dim_t value, + const nnvm::dim_t N, + const nnvm::dim_t dim_i, + const nnvm::dim_t j, + const nnvm::dim_t k, + const nnvm::dim_t t) { + KERNEL_ASSIGN(out[dim_i*N+N/(t*value)*j+i+k*N/t], req, static_cast(j)); + } +}; + +template +void IndicesCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mxnet_op; + CHECK_EQ(inputs.size(), 0U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const IndicesOpParam& param = nnvm::get(attrs.parsed); + const TBlob& out_data = outputs[0]; + mshadow::Stream *s = ctx.get_stream(); + dim_t indim = param.dimensions.ndim(); + dim_t t = 1; + dim_t N = out_data.Size()/indim; + dim_t value = 0; + if (out_data.Size() == 0) return; + if (req[0] != kNullOp) { + MSHADOW_TYPE_SWITCH(out_data.type_flag_, DType, { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + for (int i = 0; i < indim; ++i) { + value = param.dimensions[i]; + for (int k = 0; k < t; ++k) { + for (int j = 0; j < param.dimensions[i]; ++j) { + Kernel, xpu>::Launch(s, N/(param.dimensions[i] * t), + out_data.dptr(), value, N, i, j, k, t); + } + } + t = t * param.dimensions[i]; + } + }); + }); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_INIT_OP_H_ diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 1c9ebbb32866..3c5021e33a8d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -30,6 +30,7 @@ import collections import scipy.stats as ss from mxnet.test_utils import verify_generator, gen_buckets_probs_with_ppf, retry +import platform @with_seed() @@ -1664,7 +1665,7 @@ def __init__(self, sample_size, replace): super(TestWeightedChoice, self).__init__() self.sample_size = sample_size self.replace = replace - + def hybrid_forward(self, F, a, p): op = getattr(F.np.random, "choice", None) return F.np.random.choice(a, self.sample_size, self.replace, p) @@ -1703,7 +1704,7 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): assert len(samples) == samples_size if not replace: assert len(_np.unique(samples)) == samples_size - + num_classes = 10 num_samples = 10 ** 8 # Density tests are commented out due to their huge time comsumption. @@ -1717,7 +1718,7 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): # test_sample_with_replacement(np.random.choice, num_classes, shape) # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) # test_sample_with_replacement(np.random.choice, num_classes, shape, weight) - + # Tests passed locally, # commented out for the same reason as above. # shape_list2 = [ @@ -1730,7 +1731,7 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5) # weight = np.array(_np.random.dirichlet([1.0] * num_classes)) # test_sample_without_replacement(np.random.choice, num_classes, shape, 10 ** 5, weight) - + # Test hypridize mode: for hybridize in [True, False]: for replace in [True, False]: @@ -1744,6 +1745,51 @@ def test_indexing_mode(sampler, set_size, samples_size, replace, weight=None): test_indexing_mode(test_choice_weighted, num_classes, num_classes // 2, replace, weight) +@with_seed() +@use_np +def test_np_indices(): + dtypes = ['int32', 'int64', 'float16', 'float32', 'float64'] + shapes = [ + (0,), + (1,), + (2, 3, 4), + (2, 0, 4), + (1, 1, 1, 1), + (1, 0, 0, 1), + (2, 3, 4, 5, 6, 7, 8) + ] + if platform.system() == 'Windows': + shapes = shapes[1:] #beacuse in numpy windows version, indces not support dimensions is empty tuple. + for dtype in dtypes: + for shape in shapes: + np_out = _np.indices(dimensions=shape, dtype=dtype) + mx_out = np.indices(dimensions=shape, dtype=dtype) + same(mx_out.asnumpy(), np_out) + assert mx_out.shape == np_out.shape + + @use_np + class TestIndices(HybridBlock): + def __init__(self, dimensions=None, dtype=None): + super(TestIndices, self).__init__() + self._dimensions = dimensions + self._dtype = dtype + + def hybrid_forward(self, F, x): + return x + F.np.indices(dimensions=self._dimensions, dtype=self._dtype) + + for dtype in dtypes: + for shape in shapes: + x = np.zeros(shape=(), dtype=dtype) + for hybridize in [False, True]: + net = TestIndices(dimensions=shape, dtype=dtype) + np_out = _np.indices(dimensions=shape, dtype=dtype) + if hybridize: + net.hybridize() + mx_out = net(x) + same(mx_out.asnumpy(), np_out) + assert mx_out.shape == np_out.shape + + if __name__ == '__main__': import nose nose.runmodule()