From a9451cb740e74745d2a9ae2c49800144618dbcf0 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Sun, 7 Apr 2019 14:36:57 +0300 Subject: [PATCH 01/19] Implement the index_array operator --- src/operator/contrib/index_array-inl.h | 110 ++++++++++++++++ src/operator/contrib/index_array.cc | 171 +++++++++++++++++++++++++ src/operator/contrib/index_array.cu | 86 +++++++++++++ 3 files changed, 367 insertions(+) create mode 100644 src/operator/contrib/index_array-inl.h create mode 100644 src/operator/contrib/index_array.cc create mode 100644 src/operator/contrib/index_array.cu diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h new file mode 100644 index 000000000000..c32741a8b991 --- /dev/null +++ b/src/operator/contrib/index_array-inl.h @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ +#define MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ + +#include +#include +#include "../mshadow_op.h" +#include "../tensor/init_op.h" + +namespace mxnet { +namespace op { + +namespace index_array_enum { +enum IndexArrayOpInputs {kIn}; +enum IndexArrayOpOutputs {kOut}; +enum IndexArrayOpResource {kTempSpace}; +} // namespace index_array_enum + +template +struct IndexArrayKernel { + MSHADOW_XINLINE static void Map(size_t i, + int64_t* out_data, + const uint32_t n, + const int64_t* workspace) { + for (uint32_t j = 0; j < n; j++) { + int64_t upper = workspace[2 * j]; + int64_t lower = workspace[2 * j + 1]; + KERNEL_ASSIGN(out_data[i * n + j], req, (i % upper) / lower); + } + } +}; + +template +struct IndexArrayDefaultKernel { + MSHADOW_XINLINE static void Map(size_t i, + int64_t* out_data, + const uint32_t ndim, + const dim_t* shape) { + int64_t index = i; + for (uint32_t j = ndim; j-- > 0;) { + KERNEL_ASSIGN(out_data[i * ndim + j], req, index % shape[j]); + index /= shape[j]; + } + } +}; + +inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape) { + const uint32_t ndim = inshape.ndim(); + + std::vector index_products(ndim + 1); + + index_products[ndim] = 1; + + for (uint32_t i = ndim; i-- > 0;) { + index_products[i] = index_products[i + 1] * inshape[i]; + } + + return index_products; +} + +inline void IndexArrayBuildSelectedAxesWorkspace(const TShape &axes, + const std::vector &index_products, + int64_t* workspace, + const uint32_t ndim) { + for (uint32_t i = 0; i < axes.ndim(); i++) { + // Make sure that the axis is between 0 and ndim. + const dim_t axis = ((axes[i] % ndim) + ndim) % ndim; + + workspace[2 * i] = index_products[axis]; + workspace[2 * i + 1] = index_products[axis + 1]; + } +} + +template +void IndexArrayForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs); + +struct IndexArrayParam : public dmlc::Parameter { + dmlc::optional axes; + DMLC_DECLARE_PARAMETER(IndexArrayParam) { + DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional()) + .describe("The axes to include in the index array. Supports negative values."); + } +}; // struct IndexArrayParam + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_INDEX_ARRAY_INL_H_ diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc new file mode 100644 index 000000000000..d3faf4122949 --- /dev/null +++ b/src/operator/contrib/index_array.cc @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include "./index_array-inl.h" + + +namespace mxnet { +namespace op { + +template<> +void IndexArrayForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + + const IndexArrayParam& param = nnvm::get(attrs.parsed); + + const TShape inshape = in_data.shape_; + const uint32_t ndim = inshape.ndim(); + + Stream *stream = ctx.get_stream(); + + using namespace mxnet_op; + + if (param.axes.has_value()) { + const TShape& axes = param.axes.value(); + const uint32_t naxes = axes.ndim(); + + std::vector index_products = IndexArrayComputeIndexProducts(inshape); + + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(2 * naxes), stream); + + IndexArrayBuildSelectedAxesWorkspace(axes, index_products, workspace.dptr_, ndim); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, cpu>::Launch(stream, in_data.Size(), + out_data.dptr(), naxes, workspace.dptr_); + }); + } else { + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, cpu>::Launch(stream, in_data.Size(), + out_data.dptr(), ndim, inshape.data()); + }); + } +} + +DMLC_REGISTER_PARAMETER(IndexArrayParam); + +NNVM_REGISTER_OP(_contrib_index_array) +.describe(R"code(Returns an array of indexes of the input array. + +For an input array with shape :math:`(d_1, d_2, ..., d_n)`, `index_array` returns a +:math:`(d_1, d_2, ..., d_n, n)` array `idx`, where +:math:`idx[i_1, i_2, ..., i_n, :] = [i_1, i_2, ..., i_n]`. + +Additionally, when the parameter `axes` is specified, `idx` will be a +:math:`(d_1, d_2, ..., d_n, m)` array where `m` is the length of `axes`, and the following +equality will hold: :math:`idx[i_1, i_2, ..., i_n, j] = i_{axes[j]}`. + +Examples:: + + x = mx.nd.ones((3, 2)) + + mx.nd.contrib.index_array(x) = [[[0 0] + [0 1]] + + [[1 0] + [1 1]] + + [[2 0] + [2 1]]] + + x = mx.nd.ones((3, 2, 2)) + + mx.nd.contrib.index_array(x, axes=(1, 0)) = [[[[0 0] + [0 0]] + + [[1 0] + [1 0]]] + + + [[[0 1] + [0 1]] + + [[1 1] + [1 1]]] + + + [[[0 2] + [0 2]] + + [[1 2] + [1 2]]]] + +)code" ADD_FILELINE) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs &attrs) { + return std::vector{ "data" }; + }) +.set_attr("FListOutputNames", + [](const NodeAttrs &attrs) { + return std::vector{ "output" }; + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", [](const nnvm::NodeAttrs &attrs, + mxnet::ShapeVector *in_shape, + mxnet::ShapeVector *out_shape) { + const IndexArrayParam ¶m = nnvm::get(attrs.parsed); + CHECK_EQ(in_shape->size(), 1U); + CHECK_EQ(out_shape->size(), 1U); + mxnet::TShape inshape = in_shape->at(index_array_enum::kIn); + mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1U); + + for (size_t i = 0; i < inshape.ndim(); i++) { + oshape[i] = inshape[i]; + } + if (param.axes.has_value()) { + oshape[inshape.ndim()] = param.axes.value().ndim(); + } else { + oshape[inshape.ndim()] = inshape.ndim(); + } + out_shape->clear(); + out_shape->push_back(oshape); + return true; +}) +.set_attr("FInferType", [](const nnvm::NodeAttrs &attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 1U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); + return out_attrs->at(0) != -1; +}) +.set_attr("FCompute", IndexArrayForward) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FResourceRequest", [](const NodeAttrs& n) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.add_argument("data", "NDArray-or-Symbol", "Input data") +.add_arguments(IndexArrayParam::__FIELDS__()); + + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu new file mode 100644 index 000000000000..5cba11166d7b --- /dev/null +++ b/src/operator/contrib/index_array.cu @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include "./index_array-inl.h" + +namespace mxnet { +namespace op { + +using namespace mshadow::cuda; + +template<> +void IndexArrayForward(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + const TBlob& in_data = inputs[0]; + const TBlob& out_data = outputs[0]; + + const IndexArrayParam& param = nnvm::get(attrs.parsed); + + const TShape inshape = in_data.shape_; + const uint32_t ndim = inshape.ndim(); + + Stream *stream = ctx.get_stream(); + + using namespace mxnet_op; + + if (param.axes.has_value()) { + const TShape& axes = param.axes.value(); + const uint32_t naxes = axes.ndim(); + + std::vector index_products = IndexArrayComputeIndexProducts(inshape); + + std::vector cpu_workspace(2 * naxes); + IndexArrayBuildSelectedAxesWorkspace(axes, index_products, cpu_workspace.data(), ndim); + + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(2 * naxes), stream); + + CUDA_CALL(cudaMemcpy(workspace.dptr_, cpu_workspace.data(), sizeof(int64_t) * (2 * naxes), + cudaMemcpyHostToDevice)); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, gpu>::Launch(stream, in_data.Size(), + out_data.dptr(), naxes, workspace.dptr_); + }); + } else { + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(ndim), stream); + + CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * (ndim), + cudaMemcpyHostToDevice)); + + MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { + Kernel, gpu>::Launch(stream, in_data.Size(), + out_data.dptr(), ndim, workspace.dptr_); + }); + } +} + +NNVM_REGISTER_OP(_contrib_index_array) +.set_attr("FCompute", IndexArrayForward); + +} // namespace op +} // namespace mxnet From 91e20525f71b0f909350b2581912404aaa91b6f9 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Sun, 7 Apr 2019 14:38:15 +0300 Subject: [PATCH 02/19] Add index_array operator tests --- tests/python/unittest/test_operator.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index b5aa06964b29..6ea78fd3d89b 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8427,6 +8427,31 @@ def test_image_normalize(): # check backward using finite difference check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) +@with_seed() +def test_index_array_default(): + for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1) + + check_symbolic_forward(index_array, [np.ones(shape)], [expected]) + check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + +@with_seed() +def test_index_array_select_axes(): + shape = (5, 7, 11, 13, 17, 19) + for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=axes) + + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1)[..., axes] + + check_symbolic_forward(index_array, [np.ones(shape)], [expected]) + check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + @with_seed() def test_scalar_tensor_creation(): From 1a4a0b6d185e8053af05f0a532b4d7f97541af88 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Sun, 7 Apr 2019 14:38:38 +0300 Subject: [PATCH 03/19] Add index_array operator GPU tests --- tests/python/gpu/test_operator_gpu.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 2a1583ed639e..d0430a131c39 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2130,6 +2130,30 @@ def test_bilinear_sampler_versions(): if req_dict['grid'] is 'write': assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) +@with_seed() +def test_index_array_default(): + for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1) + + check_symbolic_forward(index_array, [np.ones(shape)], [expected]) + check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + +@with_seed() +def test_index_array_select_axes(): + shape = (5, 7, 11, 13, 17, 19) + for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=axes) + + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1)[..., axes] + + check_symbolic_forward(index_array, [np.ones(shape)], [expected]) + check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) # isolated execution bulking test function to be invoked with different env var settings def _test_bulking_in_process(seed, time_per_iteration): From c58964a999a726ef85907244a18d24cc42d47c1a Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Sun, 7 Apr 2019 14:39:31 +0300 Subject: [PATCH 04/19] Add the index_array operator to the Python docs autosummary --- docs/api/python/ndarray/contrib.md | 1 + docs/api/python/symbol/contrib.md | 1 + 2 files changed, 2 insertions(+) diff --git a/docs/api/python/ndarray/contrib.md b/docs/api/python/ndarray/contrib.md index f60e7f141adf..d4358ddcea22 100644 --- a/docs/api/python/ndarray/contrib.md +++ b/docs/api/python/ndarray/contrib.md @@ -75,6 +75,7 @@ In the rest of this document, we list routines provided by the `ndarray.contrib` isinf isfinite isnan + index_array index_copy getnnz edge_id diff --git a/docs/api/python/symbol/contrib.md b/docs/api/python/symbol/contrib.md index 2a6a5efe29be..38537f7487c7 100644 --- a/docs/api/python/symbol/contrib.md +++ b/docs/api/python/symbol/contrib.md @@ -72,6 +72,7 @@ In the rest of this document, we list routines provided by the `symbol.contrib` foreach while_loop cond + index_array index_copy getnnz edge_id From 50c0c2853ed1e1a4d8e610cabbeb6932c34592b1 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Sun, 7 Apr 2019 15:15:09 +0300 Subject: [PATCH 05/19] Add the author of the index_array operator to CONTRIBUTORS.md --- CONTRIBUTORS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index ab442743df08..c76f8c6edbc8 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -239,6 +239,7 @@ List of Contributors * [Zhennan Qin](https://github.com/ZhennanQin) * [Zhiyuan Huang](https://github.com/huangzhiyuan) * [Zak Jost](https://github.com/zjost) +* [Nick Guletskii](https://github.com/nickguletskii) * [Shoubhik Bhattacharya](https://github.com/shoubhik) * [Rohit Srivastava](https://github.com/access2rohit) * [Caner Turkmen](https://github.com/canerturkmen) From df1f42150645d8e0c2cbc8205a0585627657c856 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 30 Apr 2019 21:28:16 +0300 Subject: [PATCH 06/19] Make index_array compatible with zero-dim and zero-size arrays Changes the implementation of index_array to be compatible with the recently merged support for zero-dim and zero-size arrays. Resolves the incompatibilities with #14661. --- src/operator/contrib/index_array-inl.h | 22 +++++++++++----------- src/operator/contrib/index_array.cc | 8 ++++---- src/operator/contrib/index_array.cu | 6 +++--- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h index c32741a8b991..05f4265cd63e 100644 --- a/src/operator/contrib/index_array-inl.h +++ b/src/operator/contrib/index_array-inl.h @@ -36,11 +36,11 @@ enum IndexArrayOpResource {kTempSpace}; template struct IndexArrayKernel { - MSHADOW_XINLINE static void Map(size_t i, + MSHADOW_XINLINE static void Map(int i, int64_t* out_data, - const uint32_t n, + const int n, const int64_t* workspace) { - for (uint32_t j = 0; j < n; j++) { + for (int j = 0; j < n; j++) { int64_t upper = workspace[2 * j]; int64_t lower = workspace[2 * j + 1]; KERNEL_ASSIGN(out_data[i * n + j], req, (i % upper) / lower); @@ -50,12 +50,12 @@ struct IndexArrayKernel { template struct IndexArrayDefaultKernel { - MSHADOW_XINLINE static void Map(size_t i, + MSHADOW_XINLINE static void Map(int i, int64_t* out_data, - const uint32_t ndim, + const int ndim, const dim_t* shape) { int64_t index = i; - for (uint32_t j = ndim; j-- > 0;) { + for (int j = ndim - 1; j >= 0; j--) { KERNEL_ASSIGN(out_data[i * ndim + j], req, index % shape[j]); index /= shape[j]; } @@ -63,13 +63,13 @@ struct IndexArrayDefaultKernel { }; inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape) { - const uint32_t ndim = inshape.ndim(); + const int ndim = inshape.ndim(); - std::vector index_products(ndim + 1); + std::vector index_products(static_cast(ndim + 1)); index_products[ndim] = 1; - for (uint32_t i = ndim; i-- > 0;) { + for (int i = ndim - 1; i >= 0; i--) { index_products[i] = index_products[i + 1] * inshape[i]; } @@ -79,8 +79,8 @@ inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape inline void IndexArrayBuildSelectedAxesWorkspace(const TShape &axes, const std::vector &index_products, int64_t* workspace, - const uint32_t ndim) { - for (uint32_t i = 0; i < axes.ndim(); i++) { + const int ndim) { + for (int i = 0; i < axes.ndim(); i++) { // Make sure that the axis is between 0 and ndim. const dim_t axis = ((axes[i] % ndim) + ndim) % ndim; diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index d3faf4122949..9c29cdcb4c68 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -39,7 +39,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, const IndexArrayParam& param = nnvm::get(attrs.parsed); const TShape inshape = in_data.shape_; - const uint32_t ndim = inshape.ndim(); + const int ndim = inshape.ndim(); Stream *stream = ctx.get_stream(); @@ -47,7 +47,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, if (param.axes.has_value()) { const TShape& axes = param.axes.value(); - const uint32_t naxes = axes.ndim(); + const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); @@ -135,9 +135,9 @@ Examples:: CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U); mxnet::TShape inshape = in_shape->at(index_array_enum::kIn); - mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1U); + mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0); - for (size_t i = 0; i < inshape.ndim(); i++) { + for (int i = 0; i < inshape.ndim(); i++) { oshape[i] = inshape[i]; } if (param.axes.has_value()) { diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index 5cba11166d7b..0d39399cdab1 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -40,7 +40,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, const IndexArrayParam& param = nnvm::get(attrs.parsed); const TShape inshape = in_data.shape_; - const uint32_t ndim = inshape.ndim(); + const int ndim = inshape.ndim(); Stream *stream = ctx.get_stream(); @@ -48,7 +48,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, if (param.axes.has_value()) { const TShape& axes = param.axes.value(); - const uint32_t naxes = axes.ndim(); + const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); @@ -69,7 +69,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, Tensor workspace = ctx.requested[0].get_space_typed(Shape1(ndim), stream); - CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * (ndim), + CUDA_CALL(cudaMemcpy(workspace.dptr_, inshape.data(), sizeof(dim_t) * ndim, cudaMemcpyHostToDevice)); MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, { From e4152c64ab31009e9a6663f51eb8aafc117dad08 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 30 Apr 2019 21:33:44 +0300 Subject: [PATCH 07/19] Fix the index_array gradient checks in the unit tests In the previous implementation, the output gradient had an incorrect shape. This commit fixes the shapes and makes the tests more readable. --- tests/python/gpu/test_operator_gpu.py | 12 +++++++----- tests/python/unittest/test_operator.py | 13 +++++++------ 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index d0430a131c39..b147f8f6665c 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2136,11 +2136,12 @@ def test_index_array_default(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) + input_array = np.ones(shape) mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] expected = np.stack(mgrid, axis=-1) - - check_symbolic_forward(index_array, [np.ones(shape)], [expected]) - check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) @with_seed() def test_index_array_select_axes(): @@ -2149,11 +2150,12 @@ def test_index_array_select_axes(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data, axes=axes) + input_array = np.ones(shape) mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] expected = np.stack(mgrid, axis=-1)[..., axes] - check_symbolic_forward(index_array, [np.ones(shape)], [expected]) - check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) # isolated execution bulking test function to be invoked with different env var settings def _test_bulking_in_process(seed, time_per_iteration): diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 6ea78fd3d89b..66780b289730 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8433,11 +8433,12 @@ def test_index_array_default(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data) + input_array = np.ones(shape) mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] expected = np.stack(mgrid, axis=-1) - - check_symbolic_forward(index_array, [np.ones(shape)], [expected]) - check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) @with_seed() def test_index_array_select_axes(): @@ -8446,12 +8447,12 @@ def test_index_array_select_axes(): data = mx.symbol.Variable("data") index_array = mx.sym.contrib.index_array(data, axes=axes) + input_array = np.ones(shape) mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] expected = np.stack(mgrid, axis=-1)[..., axes] - check_symbolic_forward(index_array, [np.ones(shape)], [expected]) - check_symbolic_backward(index_array, [np.ones(shape)], [np.ones(shape)], [np.zeros(shape)]) - + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) @with_seed() def test_scalar_tensor_creation(): From adb9f36277f40cceec655700c94520e3f6b5bbfb Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 30 Apr 2019 21:39:02 +0300 Subject: [PATCH 08/19] Add zero-dim and zero-size array tests for index_array --- tests/python/gpu/test_operator_gpu.py | 36 ++++++++++++++++++++++++++ tests/python/unittest/test_operator.py | 36 ++++++++++++++++++++++++++ 2 files changed, 72 insertions(+) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index b147f8f6665c..4b537d1bd6c1 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2143,6 +2143,30 @@ def test_index_array_default(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) +@with_seed() +def test_index_array_default_zero_dim(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + input_array = np.ones(()) + expected = np.zeros((0,)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + +@with_seed() +def test_index_array_default_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + input_array = np.ones((0, 0, 0)) + expected = np.zeros((0, 0, 0, 3)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @with_seed() def test_index_array_select_axes(): shape = (5, 7, 11, 13, 17, 19) @@ -2157,6 +2181,18 @@ def test_index_array_select_axes(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) +@with_seed() +def test_index_array_select_axes_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + + input_array = np.ones((0, 0, 0, 0)) + expected = np.zeros((0, 0, 2)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + # isolated execution bulking test function to be invoked with different env var settings def _test_bulking_in_process(seed, time_per_iteration): data_shape = (10,) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 66780b289730..dae4d9674c5a 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8440,6 +8440,30 @@ def test_index_array_default(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) +@with_seed() +def test_index_array_default_zero_dim(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + input_array = np.ones(()) + expected = np.zeros((0,)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + +@with_seed() +def test_index_array_default_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) + + input_array = np.ones((0, 0, 0)) + expected = np.zeros((0, 0, 0, 3)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @with_seed() def test_index_array_select_axes(): shape = (5, 7, 11, 13, 17, 19) @@ -8454,6 +8478,18 @@ def test_index_array_select_axes(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) +@with_seed() +def test_index_array_select_axes_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + + input_array = np.ones((0, 0, 0, 0)) + expected = np.zeros((0, 0, 2)) + + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @with_seed() def test_scalar_tensor_creation(): assertRaises(MXNetError, mx.nd.zeros, shape=()) From 3eccb3cce01de5cda8caf0864d64dc4f3ae707a5 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 30 Apr 2019 23:53:01 +0300 Subject: [PATCH 09/19] Use mxnet::Tuple instead of TShape for the axes parameter --- src/operator/contrib/index_array-inl.h | 8 ++++---- src/operator/contrib/index_array.cc | 2 +- src/operator/contrib/index_array.cu | 2 +- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h index 05f4265cd63e..df9d6d9a7bf7 100644 --- a/src/operator/contrib/index_array-inl.h +++ b/src/operator/contrib/index_array-inl.h @@ -76,13 +76,13 @@ inline std::vector IndexArrayComputeIndexProducts(const TShape &inshape return index_products; } -inline void IndexArrayBuildSelectedAxesWorkspace(const TShape &axes, +inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple &axes, const std::vector &index_products, int64_t* workspace, const int ndim) { for (int i = 0; i < axes.ndim(); i++) { // Make sure that the axis is between 0 and ndim. - const dim_t axis = ((axes[i] % ndim) + ndim) % ndim; + const int axis = ((axes[i] % ndim) + ndim) % ndim; workspace[2 * i] = index_products[axis]; workspace[2 * i + 1] = index_products[axis + 1]; @@ -97,9 +97,9 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, const std::vector &outputs); struct IndexArrayParam : public dmlc::Parameter { - dmlc::optional axes; + dmlc::optional> axes; DMLC_DECLARE_PARAMETER(IndexArrayParam) { - DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional()) + DMLC_DECLARE_FIELD(axes).set_default(dmlc::optional>()) .describe("The axes to include in the index array. Supports negative values."); } }; // struct IndexArrayParam diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index 9c29cdcb4c68..ec2d0ea1c50d 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -46,7 +46,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, using namespace mxnet_op; if (param.axes.has_value()) { - const TShape& axes = param.axes.value(); + const mxnet::Tuple& axes = param.axes.value(); const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index 0d39399cdab1..4c89c31c64e8 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -47,7 +47,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, using namespace mxnet_op; if (param.axes.has_value()) { - const TShape& axes = param.axes.value(); + const mxnet::Tuple& axes = param.axes.value(); const int naxes = axes.ndim(); std::vector index_products = IndexArrayComputeIndexProducts(inshape); From a6b3cabb5da6dd348a71451ead2c0d6594b17131 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Thu, 2 May 2019 01:00:51 +0300 Subject: [PATCH 10/19] Fix incorrect array indexing in index_array Solves access violations when compiling with MSVC++ 14.0. --- src/operator/contrib/index_array-inl.h | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h index df9d6d9a7bf7..0ba6fc5bf50a 100644 --- a/src/operator/contrib/index_array-inl.h +++ b/src/operator/contrib/index_array-inl.h @@ -40,10 +40,10 @@ struct IndexArrayKernel { int64_t* out_data, const int n, const int64_t* workspace) { - for (int j = 0; j < n; j++) { - int64_t upper = workspace[2 * j]; - int64_t lower = workspace[2 * j + 1]; - KERNEL_ASSIGN(out_data[i * n + j], req, (i % upper) / lower); + for (ptrdiff_t j = 0; j < n; j++) { + int64_t upper = workspace[ptrdiff_t(2) * j]; + int64_t lower = workspace[ptrdiff_t(2) * j + ptrdiff_t(1)]; + KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(n) + j], req, (i % upper) / lower); } } }; @@ -55,8 +55,8 @@ struct IndexArrayDefaultKernel { const int ndim, const dim_t* shape) { int64_t index = i; - for (int j = ndim - 1; j >= 0; j--) { - KERNEL_ASSIGN(out_data[i * ndim + j], req, index % shape[j]); + for (ptrdiff_t j = ndim - 1; j >= 0; j--) { + KERNEL_ASSIGN(out_data[ptrdiff_t(i) * ptrdiff_t(ndim) + j], req, index % shape[j]); index /= shape[j]; } } @@ -84,8 +84,8 @@ inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple &axes, // Make sure that the axis is between 0 and ndim. const int axis = ((axes[i] % ndim) + ndim) % ndim; - workspace[2 * i] = index_products[axis]; - workspace[2 * i + 1] = index_products[axis + 1]; + workspace[ptrdiff_t(2) * ptrdiff_t(i)] = index_products[axis]; + workspace[ptrdiff_t(2) * ptrdiff_t(i) + ptrdiff_t(1)] = index_products[axis + 1]; } } From 69bf4f61cc3499be34ac06ac8a4f55eee14199ff Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:23:09 +0300 Subject: [PATCH 11/19] Avoid copying the input shape array in the index_array shape function --- src/operator/contrib/index_array.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index ec2d0ea1c50d..8579fc95cd4f 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -134,7 +134,7 @@ Examples:: const IndexArrayParam ¶m = nnvm::get(attrs.parsed); CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U); - mxnet::TShape inshape = in_shape->at(index_array_enum::kIn); + const mxnet::TShape &inshape = (*in_shape)[index_array_enum::kIn]; mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0); for (int i = 0; i < inshape.ndim(); i++) { From be3979895ff09f20251f7ac8c7b3ae0ee11117ad Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:24:36 +0300 Subject: [PATCH 12/19] Add unknown shape handling to index_array --- src/operator/contrib/index_array.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index 8579fc95cd4f..e587e2cc4ee2 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -135,6 +135,8 @@ Examples:: CHECK_EQ(in_shape->size(), 1U); CHECK_EQ(out_shape->size(), 1U); const mxnet::TShape &inshape = (*in_shape)[index_array_enum::kIn]; + if (!mxnet::ndim_is_known(inshape)) return false; + mxnet::TShape oshape = mxnet::TShape(inshape.ndim() + 1, 0); for (int i = 0; i < inshape.ndim(); i++) { @@ -147,7 +149,7 @@ Examples:: } out_shape->clear(); out_shape->push_back(oshape); - return true; + return shape_is_known(oshape); }) .set_attr("FInferType", [](const nnvm::NodeAttrs &attrs, std::vector *in_attrs, From 445f758a53a59bc6d5fff259b532ba3bded613e4 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:24:53 +0300 Subject: [PATCH 13/19] Use SHAPE_ASSIGN_CHECK to assign the shape in index_array --- src/operator/contrib/index_array.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index e587e2cc4ee2..bb1b25b3d3ea 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -147,8 +147,8 @@ Examples:: } else { oshape[inshape.ndim()] = inshape.ndim(); } - out_shape->clear(); - out_shape->push_back(oshape); + + SHAPE_ASSIGN_CHECK(*out_shape, 0, oshape); return shape_is_known(oshape); }) .set_attr("FInferType", [](const nnvm::NodeAttrs &attrs, From 972b05fc64eb5b5455b222f07f399b3bde55f343 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:30:22 +0300 Subject: [PATCH 14/19] Remove the redundant index_array GPU tests from test_operator_gpu.py --- tests/python/gpu/test_operator_gpu.py | 62 --------------------------- 1 file changed, 62 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 4b537d1bd6c1..2a1583ed639e 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -2130,68 +2130,6 @@ def test_bilinear_sampler_versions(): if req_dict['grid'] is 'write': assert_almost_equal(exe.grad_dict['grid'].asnumpy(), exe_list[ref_idx].grad_dict['grid'].asnumpy(), rtol=1e-3, atol=1e-5) -@with_seed() -def test_index_array_default(): - for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]: - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) - - input_array = np.ones(shape) - mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] - expected = np.stack(mgrid, axis=-1) - - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - -@with_seed() -def test_index_array_default_zero_dim(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) - - input_array = np.ones(()) - expected = np.zeros((0,)) - - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - -@with_seed() -def test_index_array_default_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) - - input_array = np.ones((0, 0, 0)) - expected = np.zeros((0, 0, 0, 3)) - - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - -@with_seed() -def test_index_array_select_axes(): - shape = (5, 7, 11, 13, 17, 19) - for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]: - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=axes) - - input_array = np.ones(shape) - mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] - expected = np.stack(mgrid, axis=-1)[..., axes] - - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) - -@with_seed() -def test_index_array_select_axes_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) - - input_array = np.ones((0, 0, 0, 0)) - expected = np.zeros((0, 0, 2)) - - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) # isolated execution bulking test function to be invoked with different env var settings def _test_bulking_in_process(seed, time_per_iteration): From 3051d5e578d9089d37e819f37943b186371c4561 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:36:41 +0300 Subject: [PATCH 15/19] Move the index_array tests into a single function (test_index_array) --- tests/python/unittest/test_operator.py | 98 ++++++++++++++------------ 1 file changed, 51 insertions(+), 47 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index dae4d9674c5a..1581cdd4d37a 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8428,67 +8428,71 @@ def test_image_normalize(): check_numeric_gradient(img_norm_sym, [data_in_4d], atol=0.001) @with_seed() -def test_index_array_default(): - for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]: - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) +def test_index_array(): + def test_index_array_default(): + for shape in [(10,), (7, 5, 29), (5, 7, 11, 13, 17, 19)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones(shape) - mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] - expected = np.stack(mgrid, axis=-1) + input_array = np.ones(shape) + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) -@with_seed() -def test_index_array_default_zero_dim(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + def test_index_array_default_zero_dim(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones(()) - expected = np.zeros((0,)) + input_array = np.ones(()) + expected = np.zeros((0,)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) -@with_seed() -def test_index_array_default_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + def test_index_array_default_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones((0, 0, 0)) - expected = np.zeros((0, 0, 0, 3)) + input_array = np.ones((0, 0, 0)) + expected = np.zeros((0, 0, 0, 3)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) -@with_seed() -def test_index_array_select_axes(): - shape = (5, 7, 11, 13, 17, 19) - for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]: - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=axes) + def test_index_array_select_axes(): + shape = (5, 7, 11, 13, 17, 19) + for axes in [(3,), (4, 1), (5, 1, 3), (-1,), (-5, -1, -3)]: + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=axes) - input_array = np.ones(shape) - mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] - expected = np.stack(mgrid, axis=-1)[..., axes] + input_array = np.ones(shape) + mgrid = np.mgrid[tuple(slice(0, x) for x in shape)] + expected = np.stack(mgrid, axis=-1)[..., axes] - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) -@with_seed() -def test_index_array_select_axes_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + def test_index_array_select_axes_zero_size(): + with mx.np_compat(active=True): + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + + input_array = np.ones((0, 0, 0, 0)) + expected = np.zeros((0, 0, 2)) - input_array = np.ones((0, 0, 0, 0)) - expected = np.zeros((0, 0, 2)) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + + test_index_array_default() + test_index_array_default_zero_dim() + test_index_array_default_zero_size() + test_index_array_select_axes() + test_index_array_select_axes_zero_size() - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) @with_seed() def test_scalar_tensor_creation(): From 3f8702af50b6d946139f7ef2a57b80ad7c336c76 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 00:38:11 +0300 Subject: [PATCH 16/19] Use @mx.use_np_compat instead of mx.np_compat in index_array op tests --- tests/python/unittest/test_operator.py | 42 +++++++++++++------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 1581cdd4d37a..9435f42a547c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8441,27 +8441,27 @@ def test_index_array_default(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_default_zero_dim(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones(()) - expected = np.zeros((0,)) + input_array = np.ones(()) + expected = np.zeros((0,)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_default_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data) - input_array = np.ones((0, 0, 0)) - expected = np.zeros((0, 0, 0, 3)) + input_array = np.ones((0, 0, 0)) + expected = np.zeros((0, 0, 0, 3)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) def test_index_array_select_axes(): shape = (5, 7, 11, 13, 17, 19) @@ -8476,16 +8476,16 @@ def test_index_array_select_axes(): check_symbolic_forward(index_array, [input_array], [expected]) check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + @mx.use_np_compat def test_index_array_select_axes_zero_size(): - with mx.np_compat(active=True): - data = mx.symbol.Variable("data") - index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) + data = mx.symbol.Variable("data") + index_array = mx.sym.contrib.index_array(data, axes=(2, 1)) - input_array = np.ones((0, 0, 0, 0)) - expected = np.zeros((0, 0, 2)) + input_array = np.ones((0, 0, 0, 0)) + expected = np.zeros((0, 0, 2)) - check_symbolic_forward(index_array, [input_array], [expected]) - check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) + check_symbolic_forward(index_array, [input_array], [expected]) + check_symbolic_backward(index_array, [input_array], [np.ones(expected.shape)], [np.zeros_like(input_array)]) test_index_array_default() test_index_array_default_zero_dim() From 496bad30a428531f6fd3fd380e52473dfa35f87a Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Tue, 14 May 2019 09:36:29 +0300 Subject: [PATCH 17/19] Remove the use of template specialization for IndexArrayForward --- src/operator/contrib/index_array-inl.h | 7 ------- src/operator/contrib/index_array.cc | 13 ++++++------- src/operator/contrib/index_array.cu | 13 ++++++------- 3 files changed, 12 insertions(+), 21 deletions(-) diff --git a/src/operator/contrib/index_array-inl.h b/src/operator/contrib/index_array-inl.h index 0ba6fc5bf50a..e280d7661b7c 100644 --- a/src/operator/contrib/index_array-inl.h +++ b/src/operator/contrib/index_array-inl.h @@ -89,13 +89,6 @@ inline void IndexArrayBuildSelectedAxesWorkspace(const mxnet::Tuple &axes, } } -template -void IndexArrayForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs); - struct IndexArrayParam : public dmlc::Parameter { dmlc::optional> axes; DMLC_DECLARE_PARAMETER(IndexArrayParam) { diff --git a/src/operator/contrib/index_array.cc b/src/operator/contrib/index_array.cc index bb1b25b3d3ea..a70dee106314 100644 --- a/src/operator/contrib/index_array.cc +++ b/src/operator/contrib/index_array.cc @@ -23,12 +23,11 @@ namespace mxnet { namespace op { -template<> -void IndexArrayForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void IndexArrayForwardCPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); @@ -159,7 +158,7 @@ Examples:: TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); return out_attrs->at(0) != -1; }) -.set_attr("FCompute", IndexArrayForward) +.set_attr("FCompute", IndexArrayForwardCPU) .set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; diff --git a/src/operator/contrib/index_array.cu b/src/operator/contrib/index_array.cu index 4c89c31c64e8..ddba6a87309a 100644 --- a/src/operator/contrib/index_array.cu +++ b/src/operator/contrib/index_array.cu @@ -24,12 +24,11 @@ namespace op { using namespace mshadow::cuda; -template<> -void IndexArrayForward(const nnvm::NodeAttrs &attrs, - const OpContext &ctx, - const std::vector &inputs, - const std::vector &req, - const std::vector &outputs) { +void IndexArrayForwardGPU(const nnvm::NodeAttrs &attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { using namespace mshadow; CHECK_EQ(inputs.size(), 1U); CHECK_EQ(outputs.size(), 1U); @@ -80,7 +79,7 @@ void IndexArrayForward(const nnvm::NodeAttrs &attrs, } NNVM_REGISTER_OP(_contrib_index_array) -.set_attr("FCompute", IndexArrayForward); +.set_attr("FCompute", IndexArrayForwardGPU); } // namespace op } // namespace mxnet From 69a99693d493e0539e8ce59f62ed21686b5c49c1 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Wed, 22 May 2019 19:05:46 +0300 Subject: [PATCH 18/19] Add the index_array operator to the AMP symbol list --- python/mxnet/contrib/amp/lists/symbol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/mxnet/contrib/amp/lists/symbol.py b/python/mxnet/contrib/amp/lists/symbol.py index 2f8b4f0f9a6a..9c99340ab75d 100644 --- a/python/mxnet/contrib/amp/lists/symbol.py +++ b/python/mxnet/contrib/amp/lists/symbol.py @@ -95,6 +95,7 @@ '_contrib_gradientmultiplier', '_contrib_group_adagrad_update', '_contrib_ifft', + '_contrib_index_array', '_contrib_index_copy', '_contrib_quadratic', '_contrib_quantize', From ffde9b3e9ff5d9f5febc4cb6e53f76b5f6423a93 Mon Sep 17 00:00:00 2001 From: Nick Guletskii Date: Wed, 22 May 2019 21:26:48 +0300 Subject: [PATCH 19/19] Retrigger CI