diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index ab442743df08..c76f8c6edbc8 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -239,6 +239,7 @@ List of Contributors
 * [Zhennan Qin](https://github.com/ZhennanQin)
 * [Zhiyuan Huang](https://github.com/huangzhiyuan)
 * [Zak Jost](https://github.com/zjost)
+* [Nick Guletskii](https://github.com/nickguletskii)
 * [Shoubhik Bhattacharya](https://github.com/shoubhik)
 * [Rohit Srivastava](https://github.com/access2rohit)
 * [Caner Turkmen](https://github.com/canerturkmen)
diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md
index f60e7f141adf..d4358ddcea22 100644
--- a/docs/api/python/ndarray/contrib.md
+++ b/docs/api/python/ndarray/contrib.md
@@ -75,6 +75,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib`
     isinf
     isfinite
     isnan
+    index_array
     index_copy
     getnnz
     edge_id
diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md
index 2a6a5efe29be..38537f7487c7 100644
--- a/docs/api/python/symbol/contrib.md
+++ b/docs/api/python/symbol/contrib.md
@@ -72,6 +72,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib`
     foreach
     while_loop
     cond
+    index_array
     index_copy
     getnnz
     edge_id
diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py
index 2f8b4f0f9a6a..9c99340ab75d 100644
--- a/python/mxnet/contrib/amp/lists/symbol.py
+++ b/python/mxnet/contrib/amp/lists/symbol.py
@@ -95,6 +95,7 @@
     '_contrib_gradientmultiplier',
     '_contrib_group_adagrad_update',
     '_contrib_ifft',
+    '_contrib_index_array',
     '_contrib_index_copy',
     '_contrib_quadratic',
     '_contrib_quantize',
diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h
new file mode 100644
index 000000000000..e280d7661b7c
--- /dev/null
+++ b/src/operator/contrib/index_array-inl.h
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
+#define MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
+
+#include <vector>
+#include <utility>
+#include "../mshadow_op.h"
+#include "../tensor/init_op.h"
+
+namespace mxnet {
+namespace op {
+
+namespace index_array_enum {
+enum IndexArrayOpInputs {kIn};
+enum IndexArrayOpOutputs {kOut};
+enum IndexArrayOpResource {kTempSpace};
+}  // namespace index_array_enum
+
+template<int req>
+struct IndexArrayKernel {
+  MSHADOW_XINLINE static void Map(int i,
+                                  int64_t* out_data,
+                                  const int n,
+                                  const int64_t* workspace) {
+    for (ptrdiff_t j = 0; j < n; j++) {
+      // For each selected axis, the workspace holds the pair
+      // (index_products[axis], index_products[axis + 1]).
+      int64_t upper = workspace[ptrdiff_t(2) * j];
+      int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)];
+      KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(n) + j], req, (i % upper) / lower);
+    }
+  }
+};
+
+template<int req>
+struct IndexArrayDefaultKernel {
+  MSHADOW_XINLINE static void Map(int i,
+                                  int64_t* out_data,
+                                  const int ndim,
+                                  const dim_t* shape) {
+    // Unravel the flat element index i into one coordinate per dimension.
+    int64_t index = i;
+    for (ptrdiff_t j = ndim - 1; j >= 0; j--) {
+      KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]);
+      index /= shape[j];
+    }
+  }
+};
+
+inline std::vector<int64_t> IndexArrayComputeIndexProducts(const TShape &inshape) {
+  const int ndim = inshape.ndim();
+
+  std::vector<int64_t> index_products(static_cast<size_t>(ndim + 1));
+
+  index_products[ndim] = 1;
+
+  for (int i = ndim - 1; i >= 0; i--) {
+    index_products[i] = index_products[i + 1] * inshape[i];
+  }
+
+  return index_products;
+}
+
+inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple<int> &axes,
+                                                 const std::vector<int64_t> &index_products,
+                                                 int64_t* workspace,
+                                                 const int ndim) {
+  for (int i = 0; i < axes.ndim(); i++) {
+    // Make sure that the axis is between 0 and ndim.
+    const int axis = ((axes[i] % ndim) + ndim) % ndim;
+
+    workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis];
+    workspace[ptrdiff_t(2) * ptrdiff_t(i) + ptrdiff_t(1)] = index_products[axis + 1];
+  }
+}
+
+struct IndexArrayParam : public dmlc::Parameter<IndexArrayParam> {
+  dmlc::optional<mxnet::Tuple<int>> axes;
+  DMLC_DECLARE_PARAMETER(IndexArrayParam) {
+    DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional<mxnet::Tuple<int>>())
+      .describe("The axes to include in the index array. Supports negative values.");
+  }
+};  // struct IndexArrayParam
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_
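The kernels above recover coordinates from a flat element index using suffix products of the
input shape: for a selected axis, `upper = index_products[axis]`, `lower = index_products[axis + 1]`,
and the coordinate of flat element `i` along that axis is `(i % upper) / lower`. A minimal NumPy
sketch of the same arithmetic (the helper names below are hypothetical, for illustration only; they
are not part of this patch)::

    import numpy as np

    def suffix_products(shape):
        # products[k] = shape[k] * shape[k + 1] * ... * shape[-1], with products[ndim] = 1
        products = [1]
        for d in reversed(shape):
            products.insert(0, products[0] * d)
        return products

    def index_array_reference(shape, axes=None):
        # Mirrors what IndexArrayKernel / IndexArrayDefaultKernel compute, element by element.
        ndim = len(shape)
        axes = list(range(ndim)) if axes is None else [a % ndim for a in axes]
        products = suffix_products(shape)
        out = np.empty(shape + (len(axes),), dtype=np.int64)
        for i in range(int(np.prod(shape, dtype=np.int64))):
            coords = [(i % products[a]) // products[a + 1] for a in axes]
            out[np.unravel_index(i, shape)] = coords
        return out

    # Agrees with stacking np.indices on the last axis and selecting the requested axes.
    shape = (3, 2, 2)
    expected = np.stack(np.indices(shape), axis=-1)[..., (1, 0)]
    assert (index_array_reference(shape, axes=(1, 0)) == expected).all()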
diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc
new file mode 100644
index 000000000000..a70dee106314
--- /dev/null
+++ b/src/operator/contrib/index_array.cc
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <mshadow/tensor.h>
+#include "./index_array-inl.h"
+
+
+namespace mxnet {
+namespace op {
+
+void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+
+  const IndexArrayParam& param = nnvm::get<IndexArrayParam>(attrs.parsed);
+
+  const TShape inshape = in_data.shape_;
+  const int ndim = inshape.ndim();
+
+  Stream<cpu> *stream = ctx.get_stream<cpu>();
+
+  using namespace mxnet_op;
+
+  if (param.axes.has_value()) {
+    const mxnet::Tuple<int>& axes = param.axes.value();
+    const int naxes = axes.ndim();
+
+    std::vector<int64_t> index_products = IndexArrayComputeIndexProducts(inshape);
+
+    Tensor<cpu, 1, int64_t> workspace =
+        ctx.requested[0].get_space_typed<cpu, 1, int64_t>(Shape1(2 * naxes), stream);
+
+    IndexArrayBuildSelectedAxesWorkspace(axes, index_products, workspace.dptr_, ndim);
+
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<IndexArrayKernel<req_type>, cpu>::Launch(stream, in_data.Size(),
+          out_data.dptr<int64_t>(), naxes, workspace.dptr_);
+    });
+  } else {
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<IndexArrayDefaultKernel<req_type>, cpu>::Launch(stream, in_data.Size(),
+          out_data.dptr<int64_t>(), ndim, inshape.data());
+    });
+  }
+}
+
+DMLC_REGISTER_PARAMETER(IndexArrayParam);
+
+NNVM_REGISTER_OP(_contrib_index_array)
+.describe(R"code(Returns an array of indexes of the input array.
+
+For an input array with shape :math:`(d_1, d_2, ..., d_n)`, `index_array` returns a
+:math:`(d_1, d_2, ..., d_n, n)` array `idx`, where
+:math:`idx[i_1, i_2, ..., i_n, :] = [i_1, i_2, ..., i_n]`.
+
+Additionally, when the parameter `axes` is specified, `idx` will be a
+:math:`(d_1, d_2, ..., d_n, m)` array where `m` is the length of `axes`, and the following
+equality will hold: :math:`idx[i_1, i_2, ..., i_n, j] = i_{axes[j]}`.
+
+Examples::
+
+    x = mx.nd.ones((3, 2))
+
+    mx.nd.contrib.index_array(x) = [[[0 0]
+                                     [0 1]]
+
+                                    [[1 0]
+                                     [1 1]]
+
+                                    [[2 0]
+                                     [2 1]]]
+
+    x = mx.nd.ones((3, 2, 2))
+
+    mx.nd.contrib.index_array(x, axes=(1, 0)) = [[[[0 0]
+                                                   [0 0]]
+
+                                                  [[1 0]
+                                                   [1 0]]]
+
+
+                                                 [[[0 1]
+                                                   [0 1]]
+
+                                                  [[1 1]
+                                                   [1 1]]]
+
+
+                                                 [[[0 2]
+                                                   [0 2]]
+
+                                                  [[1 2]
+                                                   [1 2]]]]
+
+)code" ADD_FILELINE)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+    [](const NodeAttrs &attrs) {
+      return std::vector<std::string>{ "data" };
+    })
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
+    [](const NodeAttrs &attrs) {
+      return std::vector<std::string>{ "output" };
+    })
+.set_attr_parser(ParamParser<IndexArrayParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", [](const nnvm::NodeAttrs &attrs,
+                                                mxnet::ShapeVector *in_shape,
+                                                mxnet::ShapeVector *out_shape) {
+  const IndexArrayParam &param = nnvm::get<IndexArrayParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U);
+  CHECK_EQ(out_shape->size(), 1U);
+  const mxnet::TShape &inshape = (*in_shape)[index_array_enum::kIn];
+  if (!mxnet::ndim_is_known(inshape)) return false;
+
+  // The output shape is the input shape with one extra trailing dimension
+  // holding either all coordinates or only the selected axes.
+  mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0);
+
+  for (int i = 0; i < inshape.ndim(); i++) {
+    oshape[i] = inshape[i];
+  }
+  if (param.axes.has_value()) {
+    oshape[inshape.ndim()] = param.axes.value().ndim();
+  } else {
+    oshape[inshape.ndim()] = inshape.ndim();
+  }
+
+  SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape);
+  return shape_is_known(oshape);
+})
+.set_attr<nnvm::FInferType>("FInferType", [](const nnvm::NodeAttrs &attrs,
+                                             std::vector<int> *in_attrs,
+                                             std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
+  return out_attrs->at(0) != -1;
+})
+.set_attr<FCompute>("FCompute<cpu>", IndexArrayForwardCPU)
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+})
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_arguments(IndexArrayParam::__FIELDS__());
+
+
+}  // namespace op
+}  // namespace mxnet
+
diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu
new file mode 100644
index 000000000000..ddba6a87309a
--- /dev/null
+++ b/src/operator/contrib/index_array.cu
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <mshadow/tensor.h>
+#include "./index_array-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow::cuda;
+
+void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const std::vector<TBlob> &inputs,
+                          const std::vector<OpReqType> &req,
+                          const std::vector<TBlob> &outputs) {
+  using namespace mshadow;
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+  const TBlob& in_data = inputs[0];
+  const TBlob& out_data = outputs[0];
+
+  const IndexArrayParam& param = nnvm::get<IndexArrayParam>(attrs.parsed);
+
+  const TShape inshape = in_data.shape_;
+  const int ndim = inshape.ndim();
+
+  Stream<gpu> *stream = ctx.get_stream<gpu>();
+
+  using namespace mxnet_op;
+
+  if (param.axes.has_value()) {
+    const mxnet::Tuple<int>& axes = param.axes.value();
+    const int naxes = axes.ndim();
+
+    std::vector<int64_t> index_products = IndexArrayComputeIndexProducts(inshape);
+
+    // The (upper, lower) pairs are built on the host and copied into the
+    // requested device workspace before the kernel launch.
+    std::vector<int64_t> cpu_workspace(2 * naxes);
+    IndexArrayBuildSelectedAxesWorkspace(axes, index_products, cpu_workspace.data(), ndim);
+
+    Tensor<gpu, 1, int64_t> workspace =
+        ctx.requested[0].get_space_typed<gpu, 1, int64_t>(Shape1(2 * naxes), stream);
+
+    CUDA_CALL(cudaMemcpy(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes),
+        cudaMemcpyHostToDevice));
+
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<IndexArrayKernel<req_type>, gpu>::Launch(stream, in_data.Size(),
+          out_data.dptr<int64_t>(), naxes, workspace.dptr_);
+    });
+  } else {
+    Tensor<gpu, 1, dim_t> workspace =
+        ctx.requested[0].get_space_typed<gpu, 1, dim_t>(Shape1(ndim), stream);
+
+    CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim,
+        cudaMemcpyHostToDevice));
+
+    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
+      Kernel<IndexArrayDefaultKernel<req_type>, gpu>::Launch(stream, in_data.Size(),
+          out_data.dptr<int64_t>(), ndim, workspace.dptr_);
+    });
+  }
+}
+
+NNVM_REGISTER_OP(_contrib_index_array)
+.set_attr<FCompute>("FCompute<gpu>", IndexArrayForwardGPU);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index b5aa06964b29..9435f42a547c 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -8427,6 +8427,72 @@ def test_image_normalize():
     # check backward using finite difference
     check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001)
 
+@with_seed()
+def test_index_array():
+    def test_index_array_default():
+        for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]:
+            data = mx.symbol.Variable("data")
+            index_array = mx.sym.contrib.index_array(data)
+
+            input_array = np.ones(shape)
+            mgrid = np.mgrid[tuple(slice(0, x) for x in shape)]
+            expected = np.stack(mgrid, axis=-1)
+
+            check_symbolic_forward(index_array, [input_array], [expected])
+            check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
+
+    @mx.use_np_compat
+    def test_index_array_default_zero_dim():
+        data = mx.symbol.Variable("data")
+        index_array = mx.sym.contrib.index_array(data)
+
+        input_array = np.ones(())
+        expected = np.zeros((0,))
+
+        check_symbolic_forward(index_array, [input_array], [expected])
+        check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
+
+    @mx.use_np_compat
+    def test_index_array_default_zero_size():
+        data = mx.symbol.Variable("data")
+        index_array = mx.sym.contrib.index_array(data)
+
+        input_array = np.ones((0, 0, 0))
+        expected = np.zeros((0, 0, 0, 3))
+
+        check_symbolic_forward(index_array, [input_array], [expected])
+        check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
+
+    def test_index_array_select_axes():
+        shape = (5, 7, 11, 13, 17, 19)
+        for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]:
+            data = mx.symbol.Variable("data")
+            index_array = mx.sym.contrib.index_array(data, axes=axes)
+
+            input_array = np.ones(shape)
+            mgrid = np.mgrid[tuple(slice(0, x) for x in shape)]
+            expected = np.stack(mgrid, axis=-1)[..., axes]
+
+            check_symbolic_forward(index_array, [input_array], [expected])
+            check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
+
+    @mx.use_np_compat
+    def test_index_array_select_axes_zero_size():
+        data = mx.symbol.Variable("data")
+        index_array = mx.sym.contrib.index_array(data, axes=(2, 1))
+
+        input_array = np.ones((0, 0, 0, 0))
+        expected = np.zeros((0, 0, 2))
+
+        check_symbolic_forward(index_array, [input_array], [expected])
+        check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)])
+
+    test_index_array_default()
+    test_index_array_default_zero_dim()
+    test_index_array_default_zero_size()
+    test_index_array_select_axes()
+    test_index_array_select_axes_zero_size()
+
 
 @with_seed()
 def test_scalar_tensor_creation():
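For reference, a quick interactive check of the new operator against NumPy, mirroring the
comparison used in the unit tests above (a sketch; it assumes an MXNet build that includes this
patch)::

    import mxnet as mx
    import numpy as np

    x = mx.nd.ones((3, 2, 2))

    # Full index array: shape (3, 2, 2, 3), one coordinate per input dimension.
    idx = mx.nd.contrib.index_array(x)
    assert np.array_equal(idx.asnumpy(), np.stack(np.indices((3, 2, 2)), axis=-1))

    # Restricted to selected axes (negative values count from the last axis): shape (3, 2, 2, 2).
    idx_sel = mx.nd.contrib.index_array(x, axes=(1, -1))
    assert np.array_equal(idx_sel.asnumpy(),
                          np.stack(np.indices((3, 2, 2)), axis=-1)[..., (1, -1)])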