diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index e877d35dbb5b..4ba13ca6498a 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -103,7 +103,18 @@ class NDArray {
               bool delay_alloc = true, int dtype = mshadow::default_type_flag,
               std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
               TShape storage_shape = TShape(mshadow::Shape1(0)));
-
+  /*!
+   * \brief constructs a new dynamic NDArray whose shape is unknown,
+   *        hence the NDArray is inherently lazily created
+   * \param ctx context of NDArray
+   * \param dtype data type of this ndarray
+   */
+  explicit NDArray(Context ctx, int dtype = mshadow::default_type_flag) {
+    ptr_ = std::make_shared<Chunk>(TShape(mshadow::Shape1(0)), ctx, true, dtype);
+    dtype_ = dtype;
+    storage_type_ = kDefaultStorage;
+    entry_ = {nullptr, 0, 0};
+  }
   /*!
    * \brief constructing a static NDArray that shares data with TBlob
    *  Use with caution: allocate ONLY ONE NDArray for each TBlob,
@@ -157,7 +168,20 @@ class NDArray {
       : ptr_(std::make_shared<Chunk>(stype, data, aux_data, dev_id)), shape_(shape),
         dtype_(data.type_flag_), storage_type_(stype), entry_({nullptr, 0, 0}) {
   }
-
+  /*!
+   * \brief initialize the NDArray, assuming it is not assigned a meaningful shape before
+   * \param shape the shape of the NDArray
+   */
+  void Init(const TShape &shape) {
+    ptr_->Init(shape, this->dtype_);
+    this->shape_ = shape;
+  }
+  /*!
+   * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
+   */
+  void SetShapeFromChunk() {
+    shape_ = ptr_->storage_shape;
+  }
   /*
    * This indicates whether an array is a view of another array (created by
    * reshape or slice). If an array is a view and the the data is stored in
@@ -960,7 +984,13 @@ class NDArray {
 #endif
       }
     }
-
+    /*! \brief initialize the shape and dtype, assuming it is not initialized before. */
+    void Init(const TShape &shape, int dtype) {
+      auto size = shape.Size();
+      storage_shape = shape;
+      shandle.size = size * mshadow::mshadow_sizeof(dtype);
+      this->CheckAndAlloc();
+    }
     inline void CheckAndAlloc(const TShape &shape, const std::vector<TShape> &aux_shapes,
                               int dtype) {
       // calculate size, perform allocation
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index fcbc09cacfe5..a381b2384113 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -106,8 +106,16 @@ OpStatePtr Imperative::Invoke(
   SetShapeType(ctx, attrs, inputs, outputs, &dispatch_mode);
   std::vector<OpReqType> req;
   SetWriteInplaceReq(inputs, outputs, &req);
-
-  return InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode);
+  OpStatePtr ret = InvokeOp(ctx, attrs, inputs, outputs, req, dispatch_mode);
+  // the following loop finds out the correct shape when some shapes are dynamic
+  for (size_t i = 0; i < outputs.size(); i++) {
+    if (outputs[i]->shape().ndim() == 0) {
+      // the WaitToRead overhead here does not seem to be avoidable
+      outputs[i]->WaitToRead();
+      outputs[i]->SetShapeFromChunk();
+    }
+  }
+  return ret;
 }
 
 void Imperative::MarkVariables(
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 9c86843ca7af..4b0d13167356 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -117,11 +117,13 @@ inline void SetShapeType(const Context& ctx,
   for (auto& i : outputs) {
     out_shapes.push_back(i->shape());
   }
-  CHECK(infershape.count(attrs.op))
-    << "Operator " << attrs.op->name << " is missing FInferShape attribute";
-  CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes));
-  CHECK_EQ(out_shapes.size(), outputs.size());
-
+  bool is_dynamic_shape_existing = false;
+  if (!infershape.count(attrs.op)) {
+    is_dynamic_shape_existing = true;
+  } else {
+    CHECK(infershape[attrs.op](attrs, &in_shapes, &out_shapes));
+    CHECK_EQ(out_shapes.size(), outputs.size());
+  }
   // infer type
   std::vector<int>& in_types = ret->arg_types;
   in_types.clear();
@@ -178,7 +180,10 @@ inline void SetShapeType(const Context& ctx,
   for (size_t i = 0; i < outputs.size(); ++i) {
     NDArrayStorageType storage_type = static_cast<NDArrayStorageType>(out_storage_types[i]);
     if (outputs[i]->is_none()) {
-      if (storage_type == kDefaultStorage) {
+      if (is_dynamic_shape_existing) {
+        // once there is a dynamic shape somewhere, we cannot pre-determine the shape.
+        *outputs[i] = NDArray(ctx, out_types[i]);
+      } else if (storage_type == kDefaultStorage) {
        *outputs[i] = NDArray(out_shapes[i], ctx, true, out_types[i]);
       } else {
         *outputs[i] = NDArray(storage_type, out_shapes[i], ctx, true, out_types[i]);
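Taken together, the two changes above mean that an operator without an FInferShape attribute gets a shape-less placeholder output (NDArray(ctx, dtype)), and Imperative::Invoke fills the real shape in only after the kernel has run. A minimal Python sketch of the user-visible effect, assuming a CPU build that contains this patch (the printed shapes follow from the masking semantics, not from shape inference):

    import mxnet as mx

    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    # The output shape depends on the *values* of the index array, so it cannot
    # be inferred ahead of time; Invoke() waits on the output and then copies
    # the shape out of its chunk via SetShapeFromChunk().
    print(mx.nd.contrib.boolean_mask(data, mx.nd.array([0, 1, 0])).shape)  # (1, 3)
    print(mx.nd.contrib.boolean_mask(data, mx.nd.array([1, 1, 0])).shape)  # (2, 3)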
diff --git a/src/operator/contrib/boolean_mask-inl.h b/src/operator/contrib/boolean_mask-inl.h
new file mode 100644
index 000000000000..ac0681ba927b
--- /dev/null
+++ b/src/operator/contrib/boolean_mask-inl.h
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file boolean_mask-inl.h
+*/
+
+#ifndef MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_
+#define MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_
+
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mxnet/ndarray.h>
+#include <map>
+#include <vector>
+#include <string>
+#include <utility>
+#include <algorithm>
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/init_op.h"
+#include "../mshadow_op.h"
+#include "../elemwise_op_common.h"
+
+namespace mxnet {
+namespace op {
+
+struct BooleanMaskParam : public dmlc::Parameter<BooleanMaskParam> {
+  int axis;
+  DMLC_DECLARE_PARAMETER(BooleanMaskParam) {
+    DMLC_DECLARE_FIELD(axis).set_default(0)
+    .describe("An integer that represents the axis in NDArray to mask from.");
+  }
+};
+
+template<typename xpu>
+inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
+                               const OpContext &ctx,
+                               const std::vector<NDArray> &inputs,
+                               const std::vector<OpReqType> &req,
+                               const std::vector<NDArray> &outputs) {
+  // TODO(@junrushao1994): This implementation is a proof-of-concept and
+  // hence quite slow. Performance should be improved in the future.
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
+  const int axis = param.axis;
+  const NDArray &data = inputs[0];
+  const NDArray &idx = inputs[1];
+  const NDArray &out = outputs[0];
+  CHECK_EQ(axis, 0) << "Not supported yet";
+  CHECK_EQ(data.shape()[axis], idx.shape()[0]);
+  CHECK_EQ(idx.shape().ndim(), 1U);
+  // count the number of 1s in `idx`, so that we know the output dimension
+  size_t valid_num = 0;
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    for (int i = 0; i < length; i++) {
+      if (idx_dptr[i]) {
+        ++valid_num;
+      }
+    }
+  });
+  // set the output shape forcefully
+  TShape s = data.shape();
+  s[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
+  // do the copy
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
+    for (int i = 0, j = 0; i < length; ++i) {
+      if (idx_dptr[i]) {
+        NDArray src = data.At(i);
+        NDArray dst = out.At(j++);
+        CHECK(src.shape() == dst.shape());
+        mxnet_op::copy(stream, dst.data(), src.data());
+      }
+    }
+  });
+}
+
+template<typename xpu>
+inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
+                                const OpContext &ctx,
+                                const std::vector<NDArray> &inputs,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &outputs) {
+  CHECK_EQ(inputs.size(), 3U);
+  CHECK_EQ(outputs.size(), 2U);
+  // inputs: {ograd, data, idx}
+  // outputs: {igrad_data, igrad_idx}
+  const NDArray& ograd = inputs[0];
+  const NDArray& idx = inputs[2];
+  const NDArray& igrad_data = outputs[0];
+  MSHADOW_TYPE_SWITCH(idx.dtype(), DType, {
+    DType* idx_dptr = idx.data().dptr<DType>();
+    int length = idx.shape()[0];
+    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
+    Fill<false>(stream, igrad_data.data(), req[0], 0);
+    for (int i = 0, j = 0; i < length; ++i) {
+      if (idx_dptr[i]) {
+        NDArray src = ograd.At(j++);
+        NDArray dst = igrad_data.At(i);
+        CHECK(src.shape() == dst.shape());
+        mxnet_op::copy(stream, dst.data(), src.data());
+      }
+    }
+  });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_CONTRIB_BOOLEAN_MASK_INL_H_
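For readers less familiar with the kernels above: along axis 0 they amount to NumPy-style boolean indexing, with the backward pass scattering gradient rows back to the kept positions. A reference sketch of the same semantics in NumPy (illustration only, not part of the patch):

    import numpy as np

    def boolean_mask_ref(data, index):
        # Forward: keep the rows of `data` whose entry in `index` is non-zero;
        # the number of output rows equals the number of non-zero entries.
        return data[index.astype(bool)]

    def boolean_mask_grad_ref(ograd, index):
        # Backward: scatter the rows of the output gradient back to the kept
        # positions; masked-out rows receive zero gradient.
        igrad = np.zeros((index.shape[0],) + ograd.shape[1:], dtype=ograd.dtype)
        igrad[index.astype(bool)] = ograd
        return igrad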
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
new file mode 100644
index 000000000000..2dcafb6b9494
--- /dev/null
+++ b/src/operator/contrib/boolean_mask.cc
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * Copyright (c) 2018 by Contributors
+ * \file boolean_mask.cc
+*/
+
+#include "./boolean_mask-inl.h"
+
+namespace mxnet {
+namespace op {
+
+DMLC_REGISTER_PARAMETER(BooleanMaskParam);
+
+
+bool BooleanMaskType(const nnvm::NodeAttrs& attrs,
+                     std::vector<int> *in_attrs,
+                     std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
+  return out_attrs->at(0) != -1;
+}
+
+bool BooleanMaskStorageType(const nnvm::NodeAttrs& attrs,
+                            const int dev_mask,
+                            DispatchMode* dispatch_mode,
+                            std::vector<int> *in_attrs,
+                            std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2);
+  CHECK_EQ(out_attrs->size(), 1);
+  for (int &attr : *in_attrs) {
+    CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
+  }
+  for (int &attr : *out_attrs) {
+    attr = kDefaultStorage;
+  }
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  return true;
+}
+
+bool BooleanMaskBackStorageType(const nnvm::NodeAttrs& attrs,
+                                const int dev_mask,
+                                DispatchMode* dispatch_mode,
+                                std::vector<int> *in_attrs,
+                                std::vector<int> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 3);
+  CHECK_EQ(out_attrs->size(), 2);
+  for (int &attr : *in_attrs) {
+    CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
+  }
+  for (int &attr : *out_attrs) {
+    attr = kDefaultStorage;
+  }
+  for (size_t i = 0; i < out_attrs->size(); i++)
+    out_attrs->at(i) = kDefaultStorage;
+  *dispatch_mode = DispatchMode::kFComputeEx;
+  return true;
+}
+
+NNVM_REGISTER_OP(_contrib_boolean_mask)
+.describe(R"code(
+Experimental CPU-only support for boolean masking.
+Given an n-d NDArray data and a 1-d NDArray index,
+the operator produces an n-d NDArray out whose shape cannot be pre-determined:
+it consists of the rows of data whose corresponding element in index is non-zero.
+
+>>> data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
+>>> index = mx.nd.array([0, 1, 0])
+>>> out = mx.nd.contrib.boolean_mask(data, index)
+>>> out
+
+[[4. 5. 6.]]
+<NDArray 1x3 @cpu(0)>
+
+)code" ADD_FILELINE)
+.set_attr_parser(ParamParser<BooleanMaskParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FInferType>("FInferType", BooleanMaskType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskForward<cpu>)
+.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskStorageType)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_contrib_boolean_mask"})
+.add_argument("data", "NDArray-or-Symbol", "Data")
+.add_argument("index", "NDArray-or-Symbol", "Mask")
+.add_arguments(BooleanMaskParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_contrib_boolean_mask)
+.set_num_inputs(3)
+.set_num_outputs(2)
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FInferStorageType>("FInferStorageType", BooleanMaskBackStorageType)
+.set_attr<FComputeEx>("FComputeEx<cpu>", BooleanMaskBackward<cpu>)
+.add_arguments(BooleanMaskParam::__FIELDS__());
+
+}  // namespace op
+}  // namespace mxnet
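Because FGradient routes through _backward_contrib_boolean_mask, masked-out rows receive zero gradient under autograd. A short usage sketch on a CPU context (same pattern as the unit test below, with a different mask):

    import mxnet as mx

    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    index = mx.nd.array([1, 0, 1])
    data.attach_grad()
    with mx.autograd.record():
        out = mx.nd.contrib.boolean_mask(data, index)
    out.backward()              # head gradient defaults to ones
    print(out.asnumpy())        # [[1. 2. 3.], [7. 8. 9.]]
    print(data.grad.asnumpy())  # [[1. 1. 1.], [0. 0. 0.], [1. 1. 1.]]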
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 1bf9ca0237ab..c2701977a521 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4793,6 +4793,24 @@ def test_index_copy():
     assert same(x.grad.asnumpy(), x_grad.asnumpy())
     assert same(t.grad.asnumpy(), t_grad.asnumpy())
 
+
+@with_seed()
+def test_boolean_mask():
+    if default_context().device_type != 'cpu':
+        return
+    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    index = mx.nd.array([0, 1, 0])
+    data.attach_grad()
+    with mx.autograd.record():
+        out = mx.nd.contrib.boolean_mask(data, index)
+    out.backward()
+    data.grad.wait_to_read()
+    expected = np.array([[4, 5, 6]])
+    expected_grad = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]])
+    assert same(out.asnumpy(), expected)
+    assert same(data.grad.asnumpy(), expected_grad)
+
+
 @with_seed()
 def test_div_sqrt_dim():
     data_tmp = np.random.normal(0, 1, (5, 10, 8))