From 84353fb1f2d959d3ed5455c3b0cef1708c8480f9 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 15 Nov 2018 08:46:47 -0500 Subject: [PATCH 01/10] Add NDArray with lazy shape, Add boolean mask operator --- include/mxnet/ndarray.h | 36 +++++++++- src/imperative/imperative.cc | 13 +++- src/imperative/imperative_utils.h | 18 +++-- src/operator/contrib/boolean_mask-inl.h | 91 +++++++++++++++++++++++++ src/operator/contrib/boolean_mask.cc | 59 ++++++++++++++++ 5 files changed, 206 insertions(+), 11 deletions(-) create mode 100644 src/operator/contrib/boolean_mask-inl.h create mode 100644 src/operator/contrib/boolean_mask.cc diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index e877d35dbb5b..1a162ae19286 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -103,7 +103,18 @@ class NDArray { bool delay_alloc = true, int dtype = mshadow::default_type_flag, std::vector aux_types = {}, std::vector aux_shapes = {}, TShape storage_shape = TShape(mshadow::Shape1(0))); - + /*! + * \brief constructs a new dynamic NDArray whose shape is unknown, + * hence the NDArray is inherently lazily created + * \param ctx context of NDArray + * \param dtype data type of this ndarray + */ + NDArray(Context ctx, int dtype = mshadow::default_type_flag) { + ptr_ = std::make_shared(TShape(mshadow::Shape1(0)), ctx, true, dtype); + dtype_ = dtype; + storage_type_ = kDefaultStorage; + entry_ = {nullptr, 0, 0}; + } /*! * \brief constructing a static NDArray that shares data with TBlob * Use with caution: allocate ONLY ONE NDArray for each TBlob, @@ -157,7 +168,20 @@ class NDArray { : ptr_(std::make_shared(stype, data, aux_data, dev_id)), shape_(shape), dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) { } - + /*! + * \brief initialize the NDArray, assuming it is not assigned a meaningful shape before + * \param shape the shape of the NDArray + */ + void Init(const TShape &shape) { + ptr_->Init(shape, this->dtype_); + this->shape_ = shape; + } + /*! 
+ * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk. + */ + void SetShapeFromChunk() { + shape_ = ptr_->storage_shape; + } /* * This indicates whether an array is a view of another array (created by * reshape or slice). If an array is a view and the the data is stored in @@ -960,7 +984,13 @@ class NDArray { #endif } } - + /*! \brief initialize the shape and dtype for this Chunk, assuming it is not initialized before. */ + void Init(const TShape &shape, int dtype) { + auto size = shape.Size(); + storage_shape = shape; + shandle.size = size * mshadow::mshadow_sizeof(dtype); + this->CheckAndAlloc(); + } inline void CheckAndAlloc(const TShape &shape, const std::vector &aux_shapes, int dtype) { // calculate size, perform allocation diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index 32ff8d338131..de2f3d3a4d23 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -106,8 +106,17 @@ OpStatePtr Imperative::Invoke( SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode); std::vector req; SetWriteInplaceReq(inputs, outputs, &req); - - return InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode); + OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode); + // the following loop is used for finding out the correct shape when some shapes are dynamic + for (size_t i = 0; i < outputs.size(); i++) { + if (outputs[i]->shape().ndim() == 0) { + // the WaitToRead overhead here does not seem to be avoidable + outputs[i]->WaitToRead(); + outputs[i]->SetShapeFromChunk(); + } + CHECK(outputs[i]->shape().ndim()); + } + return ret; } void Imperative::MarkVariables( diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 9c86843ca7af..0a45b6b5b1b1 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -117,11 +117,14 @@ inline void SetShapeType(const Context& ctx, for (auto& i : outputs) { 
out_shapes.push_back(i->shape()); } - CHECK(infershape.count(attrs.op)) - << "Operator " << attrs.op->name << " is missing FInferShape attribute"; - CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes)); - CHECK_EQ(out_shapes.size(), outputs.size()); - + const bool dynamic_shape_exists = [&]() { + if (!infershape.count(attrs.op)) { + return true; + } + CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes)); + CHECK_EQ(out_shapes.size(), outputs.size()); + return false; + }(); // infer type std::vector& in_types = ret->arg_types; in_types.clear(); @@ -178,7 +181,10 @@ inline void SetShapeType(const Context& ctx, for (size_t i = 0; i < outputs.size(); ++i) { NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (outputs[i]->is_none()) { - if (storage_type == kDefaultStorage) { + if (dynamic_shape_exists) { + // once there is dynamic shape somewhere, we could not pre-determine the shape. + *outputs[i] = NDArray(ctx, out_types[i]); + } else if (storage_type == kDefaultStorage) { *outputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]); } else { *outputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]); diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h new file mode 100644 index 000000000000..d53fc22511bd --- /dev/null +++ b/src/operator/contrib/boolean_mask-inl.h @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file boolean_mask-inl.h +*/ + +#ifndef MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ +#define MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mxnet_op.h" +#include "../mshadow_op.h" + +namespace mxnet { +namespace op { + +struct BooleanMaskParam : public dmlc::Parameter { + int axis; + DMLC_DECLARE_PARAMETER(BooleanMaskParam) { + DMLC_DECLARE_FIELD(axis).set_default(0) + .describe("An integer that represents the axis in NDArray to mask from."); + } +}; + +template +inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 2U); + CHECK_EQ(outputs.size(), 1U); + const BooleanMaskParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.axis, 0); + CHECK_EQ(inputs[0].shape()[param.axis], inputs[1].shape()[0]); + CHECK_EQ(inputs[1].shape().ndim(), 1U); + size_t valid_num = 0; + const TBlob &idx = inputs[1].data(); + MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { + for (int i = 0; i < inputs[1].shape()[0]; i++) { + if (idx.dptr()[i]) + valid_num++; + } + }); + TShape s = inputs[0].shape(); + s[0] = valid_num; + const_cast(outputs[0]).Init(s); + size_t j = 0; + size_t ele_size = mshadow::mshadow_sizeof(inputs[0].dtype()); + MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { + for (int i = 0; i < inputs[1].shape()[0]; i++) { + if (idx.dptr()[i]) { + 
NDArray src = inputs[0].At(i); + NDArray dst = outputs[0].At(j); + CHECK(src.shape() == dst.shape()); + memcpy(dst.data().dptr_, src.data().dptr_, src.shape().Size() * ele_size); + j++; + } + } + }); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_ diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc new file mode 100644 index 000000000000..52e868c93f10 --- /dev/null +++ b/src/operator/contrib/boolean_mask.cc @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! 
+ * Copyright (c) 2018 by Contributors + * \file boolean_mask.cc +*/ + +#include "./boolean_mask-inl.h" + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(BooleanMaskParam); + +bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + for (size_t i = 0; i < out_attrs->size(); i++) + out_attrs->at(i) = kDefaultStorage; + *dispatch_mode = DispatchMode::kFComputeEx; + return true; +} + +NNVM_REGISTER_OP(_contrib_BooleanMask) +.describe(R"code( +)code" ADD_FILELINE) +.set_attr_parser(ParamParser) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FComputeEx", BooleanMaskForward) +.set_attr("FInferStorageType", BooleanMaskStorageType) +//.set_attr("FGradient", +// ElemwiseGradUseNone{"_backward_contrib_BooleanMask"}) +.add_argument("data", "NDArray-or-Symbol", "Data") +.add_argument("index", "NDArray-or-Symbol", "Mask") +.add_arguments(BooleanMaskParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet From c3aa226005f82db00fdfacaa76583f17c570b1e4 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 15 Nov 2018 09:00:38 -0500 Subject: [PATCH 02/10] Fix lints --- include/mxnet/ndarray.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h index 1a162ae19286..4ba13ca6498a 100644 --- a/include/mxnet/ndarray.h +++ b/include/mxnet/ndarray.h @@ -109,7 +109,7 @@ class NDArray { * \param ctx context of NDArray * \param dtype data type of this ndarray */ - NDArray(Context ctx, int dtype = mshadow::default_type_flag) { + explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag) { ptr_ = std::make_shared(TShape(mshadow::Shape1(0)), ctx, true, dtype); dtype_ = dtype; storage_type_ = kDefaultStorage; @@ -984,7 +984,7 @@ class NDArray { #endif } } - /*! 
\brief initialize the shape and dtype for this Chunk, assuming it is not initialized before. */ + /*! \brief initialize the shape and dtype, assuming it is not initialized before. */ void Init(const TShape &shape, int dtype) { auto size = shape.Size(); storage_shape = shape; From 822901ec5e2d56b5d531da927a3e6de7ccf470cc Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 16 Nov 2018 10:45:36 -0500 Subject: [PATCH 03/10] Address comments and refactor forward pass --- src/imperative/imperative_utils.h | 13 +++--- src/operator/contrib/boolean_mask-inl.h | 55 ++++++++++++++----------- src/operator/contrib/boolean_mask.cc | 9 ++-- 3 files changed, 44 insertions(+), 33 deletions(-) diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h index 0a45b6b5b1b1..4b0d13167356 100644 --- a/src/imperative/imperative_utils.h +++ b/src/imperative/imperative_utils.h @@ -117,14 +117,13 @@ inline void SetShapeType(const Context& ctx, for (auto& i : outputs) { out_shapes.push_back(i->shape()); } - const bool dynamic_shape_exists = [&]() { - if (!infershape.count(attrs.op)) { - return true; - } + bool is_dynamic_shape_existing = false; + if (!infershape.count(attrs.op)) { + is_dynamic_shape_existing = true; + } else { CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes)); CHECK_EQ(out_shapes.size(), outputs.size()); - return false; - }(); + } // infer type std::vector& in_types = ret->arg_types; in_types.clear(); @@ -181,7 +180,7 @@ inline void SetShapeType(const Context& ctx, for (size_t i = 0; i < outputs.size(); ++i) { NDArrayStorageType storage_type = static_cast(out_storage_types[i]); if (outputs[i]->is_none()) { - if (dynamic_shape_exists) { + if (is_dynamic_shape_existing) { // once there is dynamic shape somewhere, we could not pre-determine the shape. 
*outputs[i] = NDArray(ctx, out_types[i]); } else if (storage_type == kDefaultStorage) { diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index d53fc22511bd..e721a983779e 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -35,6 +35,7 @@ #include "../operator_common.h" #include "../mxnet_op.h" #include "../mshadow_op.h" +#include "../elemwise_op_common.h" namespace mxnet { namespace op { @@ -47,7 +48,6 @@ struct BooleanMaskParam : public dmlc::Parameter { } }; -template inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -56,32 +56,41 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); const BooleanMaskParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(param.axis, 0); - CHECK_EQ(inputs[0].shape()[param.axis], inputs[1].shape()[0]); - CHECK_EQ(inputs[1].shape().ndim(), 1U); + const int axis = param.axis; + const NDArray &data = inputs[0]; + const NDArray &idx = inputs[1]; + const NDArray &out = outputs[0]; + CHECK_EQ(axis, 0) << "Not supported yet"; + CHECK_EQ(data.shape()[axis], idx.shape()[0]); + CHECK_EQ(idx.shape().ndim(), 1U); + // count the number of 1s in `idx`, so that we could know the output dimension size_t valid_num = 0; - const TBlob &idx = inputs[1].data(); - MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { - for (int i = 0; i < inputs[1].shape()[0]; i++) { - if (idx.dptr()[i]) - valid_num++; + MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { + DType* idx_dptr = idx.data().dptr(); + int length = idx.shape()[0]; + for (int i = 0; i < length; i++) { + if (idx_dptr[i]) { + ++valid_num; + } } }); - TShape s = inputs[0].shape(); - s[0] = valid_num; - const_cast(outputs[0]).Init(s); - size_t j = 0; - size_t ele_size = mshadow::mshadow_sizeof(inputs[0].dtype()); - MSHADOW_TYPE_SWITCH(inputs[1].dtype(), DType, { - for (int i = 0; i < 
inputs[1].shape()[0]; i++) { - if (idx.dptr()[i]) { - NDArray src = inputs[0].At(i); - NDArray dst = outputs[0].At(j); - CHECK(src.shape() == dst.shape()); - memcpy(dst.data().dptr_, src.data().dptr_, src.shape().Size() * ele_size); - j++; + // set the output shape forcefully + TShape s = data.shape(); + s[axis] = valid_num; + const_cast(out).Init(s); + // do the copy + MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { + DType* idx_dptr = idx.data().dptr(); + int length = idx.shape()[0]; + mshadow::Stream *stream = ctx.get_stream(); + for (int i = 0, j = 0; i < length; i++) { + if (idx_dptr[i]) { + NDArray src = data.At(i); + NDArray dst = out.At(j++); + CHECK(src.shape() == dst.shape()); + mxnet_op::copy(stream, dst.data(), src.data()); + } } - } }); } diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 52e868c93f10..17edf431d1cb 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -41,16 +41,19 @@ bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, return true; } +// TODO(@junrushao1994): update the docstring after the PR is almost done. NNVM_REGISTER_OP(_contrib_BooleanMask) .describe(R"code( +Experimental CPU-only support for boolean masking. +Given an NDArray x, and a 1-d NDArray index, +the operator produces an un-predeterminable shaped 2-d NDArray y, +which stands for the rows in x where the corresonding element in index is non-zero. 
)code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(2) .set_num_outputs(1) -.set_attr("FComputeEx", BooleanMaskForward) +.set_attr("FComputeEx", BooleanMaskForward) .set_attr("FInferStorageType", BooleanMaskStorageType) -//.set_attr("FGradient", -// ElemwiseGradUseNone{"_backward_contrib_BooleanMask"}) .add_argument("data", "NDArray-or-Symbol", "Data") .add_argument("index", "NDArray-or-Symbol", "Mask") .add_arguments(BooleanMaskParam::__FIELDS__()); From b30cc96ff1cf7cdce005ede3ee3f8c4146b05403 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 16 Nov 2018 12:50:47 -0500 Subject: [PATCH 04/10] Add backward pass for boolean index --- src/operator/contrib/boolean_mask-inl.h | 29 +++++++++++++++++++++++++ src/operator/contrib/boolean_mask.cc | 8 +++++++ 2 files changed, 37 insertions(+) diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index e721a983779e..b3964eedbf27 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -32,6 +32,7 @@ #include #include #include +#include #include "../operator_common.h" #include "../mxnet_op.h" #include "../mshadow_op.h" @@ -94,6 +95,34 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, }); } +inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 3U); + CHECK_EQ(outputs.size(), 2U); + // inputs: {ograd, data, idx} + // outputs: {igrad_data, igrad_idx} + // TODO(@junrushao1994): how to declare no igrad w.r.t. index? 
+ const NDArray& ograd = inputs[0]; + const NDArray& idx = inputs[2]; + const NDArray& igrad_data = outputs[0]; + MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { + DType* idx_dptr = idx.data().dptr(); + int length = idx.shape()[0]; + mshadow::Stream *stream = ctx.get_stream(); + for (int i = 0, j = 0; i < length; i++) { + if (idx_dptr[i]) { + NDArray src = ograd.At(j++); + NDArray dst = igrad_data.At(i); + CHECK(src.shape() == dst.shape()); + mxnet_op::copy(stream, dst.data(), src.data()); + } + } + }); +} + } // namespace op } // namespace mxnet diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 17edf431d1cb..889fb5bcc073 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -54,9 +54,17 @@ which stands for the rows in x where the corresonding element in index is non-ze .set_num_outputs(1) .set_attr("FComputeEx", BooleanMaskForward) .set_attr("FInferStorageType", BooleanMaskStorageType) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_BooleanMask"}) .add_argument("data", "NDArray-or-Symbol", "Data") .add_argument("index", "NDArray-or-Symbol", "Mask") .add_arguments(BooleanMaskParam::__FIELDS__()); +NNVM_REGISTER_OP(_backward_contrib_BooleanMask) +.set_num_inputs(3) +.set_num_outputs(2) +.set_attr("TIsBackward", true) +.set_attr("FComputeEx", BooleanMaskBackward) +.add_arguments(BooleanMaskParam::__FIELDS__()); + } // namespace op } // namespace mxnet From 1f615821d99cded70e2168b60121bb140bce9f03 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 16 Nov 2018 15:54:23 -0500 Subject: [PATCH 05/10] Fix tests.... 
--- src/imperative/imperative.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc index de2f3d3a4d23..a8d22abb560d 100644 --- a/src/imperative/imperative.cc +++ b/src/imperative/imperative.cc @@ -114,7 +114,6 @@ OpStatePtr Imperative::Invoke( outputs[i]->WaitToRead(); outputs[i]->SetShapeFromChunk(); } - CHECK(outputs[i]->shape().ndim()); } return ret; } From de471df071a234faf027160d7783b450a89d67f9 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 16 Nov 2018 23:43:01 -0500 Subject: [PATCH 06/10] Add unittests --- src/operator/contrib/boolean_mask-inl.h | 8 +++++--- src/operator/contrib/boolean_mask.cc | 20 +++++++++++++++++--- tests/python/unittest/test_operator.py | 16 ++++++++++++++++ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index b3964eedbf27..b84bd6a9e4d2 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -84,7 +84,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, DType* idx_dptr = idx.data().dptr(); int length = idx.shape()[0]; mshadow::Stream *stream = ctx.get_stream(); - for (int i = 0, j = 0; i < length; i++) { + for (int i = 0, j = 0; i < length; ++i) { if (idx_dptr[i]) { NDArray src = data.At(i); NDArray dst = out.At(j++); @@ -104,7 +104,6 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, CHECK_EQ(outputs.size(), 2U); // inputs: {ograd, data, idx} // outputs: {igrad_data, igrad_idx} - // TODO(@junrushao1994): how to declare no igrad w.r.t. index? 
const NDArray& ograd = inputs[0]; const NDArray& idx = inputs[2]; const NDArray& igrad_data = outputs[0]; @@ -112,7 +111,10 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, DType* idx_dptr = idx.data().dptr(); int length = idx.shape()[0]; mshadow::Stream *stream = ctx.get_stream(); - for (int i = 0, j = 0; i < length; i++) { + MSHADOW_TYPE_SWITCH(igrad_data.dtype(), igrad_data_DType, { + mxnet_op::Kernel::Launch(stream, igrad_data.data().Size(), igrad_data.data().dptr()); + }); + for (int i = 0, j = 0; i < length; ++i) { if (idx_dptr[i]) { NDArray src = ograd.At(j++); NDArray dst = igrad_data.At(i); diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 889fb5bcc073..eed188d3d84f 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -41,8 +41,21 @@ bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, return true; } +bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 3); + CHECK_EQ(out_attrs->size(), 2); + for (size_t i = 0; i < out_attrs->size(); i++) + out_attrs->at(i) = kDefaultStorage; + *dispatch_mode = DispatchMode::kFComputeEx; + return true; +} + // TODO(@junrushao1994): update the docstring after the PR is almost done. -NNVM_REGISTER_OP(_contrib_BooleanMask) +NNVM_REGISTER_OP(_contrib_boolean_mask) .describe(R"code( Experimental CPU-only support for boolean masking. 
Given an NDArray x, and a 1-d NDArray index, @@ -54,15 +67,16 @@ which stands for the rows in x where the corresonding element in index is non-ze .set_num_outputs(1) .set_attr("FComputeEx", BooleanMaskForward) .set_attr("FInferStorageType", BooleanMaskStorageType) -.set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_BooleanMask"}) +.set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"}) .add_argument("data", "NDArray-or-Symbol", "Data") .add_argument("index", "NDArray-or-Symbol", "Mask") .add_arguments(BooleanMaskParam::__FIELDS__()); -NNVM_REGISTER_OP(_backward_contrib_BooleanMask) +NNVM_REGISTER_OP(_backward_contrib_boolean_mask) .set_num_inputs(3) .set_num_outputs(2) .set_attr("TIsBackward", true) +.set_attr("FInferStorageType", BooleanMaskBackStorageType) .set_attr("FComputeEx", BooleanMaskBackward) .add_arguments(BooleanMaskParam::__FIELDS__()); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 5fe9e3e048aa..7bd393584c1c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4792,6 +4792,22 @@ def test_index_copy(): assert same(t.grad.asnumpy(), t_grad.asnumpy()) assert same(index.grad.asnumpy(), index_grad.asnumpy()) + +@with_seed() +def test_boolean_mask(): + data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) + index = mx.nd.array([0, 1, 0]) + data.attach_grad() + with mx.autograd.record(): + out = mx.nd.contrib.boolean_mask(data, index) + out.backward() + data.grad.wait_to_read() + expected = np.array([[4, 5, 6]]) + expected_grad = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]]) + assert same(out.asnumpy(), expected) + assert same(data.grad.asnumpy(), expected_grad) + + @with_seed() def test_div_sqrt_dim(): data_tmp = np.random.normal(0, 1, (5, 10, 8)) From e3f4c4022ddcc3a7dfa5022327b8e5e645cd3fcf Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Fri, 16 Nov 2018 23:51:27 -0500 Subject: [PATCH 07/10] Make lint happy --- 
src/operator/contrib/boolean_mask-inl.h | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index b84bd6a9e4d2..87798a19c1a2 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -112,7 +112,10 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, int length = idx.shape()[0]; mshadow::Stream *stream = ctx.get_stream(); MSHADOW_TYPE_SWITCH(igrad_data.dtype(), igrad_data_DType, { - mxnet_op::Kernel::Launch(stream, igrad_data.data().Size(), igrad_data.data().dptr()); + mxnet_op::Kernel::Launch( + stream, + igrad_data.data().Size(), + igrad_data.data().dptr()); }); for (int i = 0, j = 0; i < length; ++i) { if (idx_dptr[i]) { From c092998dfad15430766eef4e0dca61e2ffb71a04 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 18 Nov 2018 03:18:21 -0500 Subject: [PATCH 08/10] Address comments --- src/operator/contrib/boolean_mask-inl.h | 8 +++++--- src/operator/contrib/boolean_mask.cc | 16 ++++++++++++++-- tests/python/unittest/test_operator.py | 2 ++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index 87798a19c1a2..5e3f5044b461 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -49,6 +49,7 @@ struct BooleanMaskParam : public dmlc::Parameter { } }; +template inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -83,7 +84,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { DType* idx_dptr = idx.data().dptr(); int length = idx.shape()[0]; - mshadow::Stream *stream = ctx.get_stream(); + mshadow::Stream *stream = ctx.get_stream(); for (int i = 0, j = 0; i < length; ++i) { if (idx_dptr[i]) { NDArray src = data.At(i); @@ -95,6 +96,7 @@ inline void 
BooleanMaskForward(const nnvm::NodeAttrs& attrs, }); } +template inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &inputs, @@ -110,9 +112,9 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, MSHADOW_TYPE_SWITCH(idx.dtype(), DType, { DType* idx_dptr = idx.data().dptr(); int length = idx.shape()[0]; - mshadow::Stream *stream = ctx.get_stream(); + mshadow::Stream *stream = ctx.get_stream(); MSHADOW_TYPE_SWITCH(igrad_data.dtype(), igrad_data_DType, { - mxnet_op::Kernel::Launch( + mxnet_op::Kernel::Launch( stream, igrad_data.data().Size(), igrad_data.data().dptr()); diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index eed188d3d84f..8603ddbdd3a0 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -28,6 +28,17 @@ namespace op { DMLC_REGISTER_PARAMETER(BooleanMaskParam); + +bool BooleanMaskType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 2); + CHECK_EQ(out_attrs->size(), 1); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + return out_attrs->at(0) != -1; +} + bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask, DispatchMode* dispatch_mode, @@ -65,7 +76,8 @@ which stands for the rows in x where the corresonding element in index is non-ze .set_attr_parser(ParamParser) .set_num_inputs(2) .set_num_outputs(1) -.set_attr("FComputeEx", BooleanMaskForward) +.set_attr("FInferType", BooleanMaskType) +.set_attr("FComputeEx", BooleanMaskForward) .set_attr("FInferStorageType", BooleanMaskStorageType) .set_attr("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"}) .add_argument("data", "NDArray-or-Symbol", "Data") @@ -77,7 +89,7 @@ NNVM_REGISTER_OP(_backward_contrib_boolean_mask) .set_num_outputs(2) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", 
BooleanMaskBackStorageType) -.set_attr("FComputeEx", BooleanMaskBackward) +.set_attr("FComputeEx", BooleanMaskBackward) .add_arguments(BooleanMaskParam::__FIELDS__()); } // namespace op diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 7bd393584c1c..e4465ad8cef9 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -4795,6 +4795,8 @@ def test_index_copy(): @with_seed() def test_boolean_mask(): + if default_context().device_type != 'cpu': + return data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) index = mx.nd.array([0, 1, 0]) data.attach_grad() From 2e6a4e64b929efe44ba0ef199e7963b532d78076 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 18 Nov 2018 21:42:02 -0500 Subject: [PATCH 09/10] Address comments, and complete docstring --- src/operator/contrib/boolean_mask-inl.h | 10 ++++----- src/operator/contrib/boolean_mask.cc | 28 ++++++++++++++++++++----- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h index 5e3f5044b461..ac0681ba927b 100644 --- a/src/operator/contrib/boolean_mask-inl.h +++ b/src/operator/contrib/boolean_mask-inl.h @@ -35,6 +35,7 @@ #include #include "../operator_common.h" #include "../mxnet_op.h" +#include "../tensor/init_op.h" #include "../mshadow_op.h" #include "../elemwise_op_common.h" @@ -55,6 +56,8 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs, const std::vector &inputs, const std::vector &req, const std::vector &outputs) { + // TODO(@junrushao1994): This implementation is a proof-of-concept, + // hence very slow actually. Performance should be improved in the future. 
CHECK_EQ(inputs.size(), 2U); CHECK_EQ(outputs.size(), 1U); const BooleanMaskParam& param = nnvm::get(attrs.parsed); @@ -113,12 +116,7 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs, DType* idx_dptr = idx.data().dptr(); int length = idx.shape()[0]; mshadow::Stream *stream = ctx.get_stream(); - MSHADOW_TYPE_SWITCH(igrad_data.dtype(), igrad_data_DType, { - mxnet_op::Kernel::Launch( - stream, - igrad_data.data().Size(), - igrad_data.data().dptr()); - }); + Fill(stream, igrad_data.data(), req[0], 0); for (int i = 0, j = 0; i < length; ++i) { if (idx_dptr[i]) { NDArray src = ograd.At(j++); diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index 8603ddbdd3a0..f826f236ebe4 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -46,8 +46,12 @@ bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 2); CHECK_EQ(out_attrs->size(), 1); - for (size_t i = 0; i < out_attrs->size(); i++) - out_attrs->at(i) = kDefaultStorage; + for (int &attr : *in_attrs) { + CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; + } + for (int &attr : *out_attrs) { + attr = kDefaultStorage; + } *dispatch_mode = DispatchMode::kFComputeEx; return true; } @@ -59,19 +63,33 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs, std::vector *out_attrs) { CHECK_EQ(in_attrs->size(), 3); CHECK_EQ(out_attrs->size(), 2); + for (int &attr : *in_attrs) { + CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; + } + for (int &attr : *out_attrs) { + attr = kDefaultStorage; + } for (size_t i = 0; i < out_attrs->size(); i++) out_attrs->at(i) = kDefaultStorage; *dispatch_mode = DispatchMode::kFComputeEx; return true; } -// TODO(@junrushao1994): update the docstring after the PR is almost done. NNVM_REGISTER_OP(_contrib_boolean_mask) .describe(R"code( Experimental CPU-only support for boolean masking. 
-Given an NDArray x, and a 1-d NDArray index, -the operator produces an un-predeterminable shaped 2-d NDArray y, +Given an n-d NDArray x, and a 1-d NDArray index, +the operator produces an un-predeterminable shaped n-d NDArray y, which stands for the rows in x where the corresonding element in index is non-zero. + +>>> x = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) +>>> index = mx.nd.array([0, 1, 0]) +>>> y = mx.nd.contrib.boolean_mask(data, index) +>>> y + +[[4. 5. 6.]] + + )code" ADD_FILELINE) .set_attr_parser(ParamParser) .set_num_inputs(2) From f6340446cee85e75d08ce815cac77b6bf261d00d Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 18 Nov 2018 22:22:28 -0500 Subject: [PATCH 10/10] Update docstring --- src/operator/contrib/boolean_mask.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc index f826f236ebe4..2dcafb6b9494 100644 --- a/src/operator/contrib/boolean_mask.cc +++ b/src/operator/contrib/boolean_mask.cc @@ -78,14 +78,14 @@ bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs, NNVM_REGISTER_OP(_contrib_boolean_mask) .describe(R"code( Experimental CPU-only support for boolean masking. -Given an n-d NDArray x, and a 1-d NDArray index, -the operator produces an un-predeterminable shaped n-d NDArray y, +Given an n-d NDArray data, and a 1-d NDArray index, +the operator produces an un-predeterminable shaped n-d NDArray out, which stands for the rows in x where the corresonding element in index is non-zero. ->>> x = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) +>>> data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]]) >>> index = mx.nd.array([0, 1, 0]) ->>> y = mx.nd.contrib.boolean_mask(data, index) ->>> y +>>> out = mx.nd.contrib.boolean_mask(data, index) +>>> out [[4. 5. 6.]]