diff --git a/include/mxnet/ndarray.h b/include/mxnet/ndarray.h
index 428245b56d0e..176aa0aaa197 100644
--- a/include/mxnet/ndarray.h
+++ b/include/mxnet/ndarray.h
@@ -190,11 +190,7 @@ class NDArray {
   /*!
    * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
    */
-  void SetShapeFromChunk() {
-    if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
-      shape_ = ptr_->storage_shape;
-    }
-  }
+  void SetShapeFromChunk();
   /*
    * This indicates whether an array is a view of another array (created by
    * reshape or slice). If an array is a view and the data is stored in
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e2c0c9d4c9d4..c00021c44d1d 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -313,7 +313,9 @@ std::vector<NDArray*> Imperative::Backward(
     } else {
       info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
                                 true, outputs[i]->dtype());
-      info.outputs.back() = static_cast<real_t>(1.0);
+      if (info.outputs.back().shape().Size() != 0) {
+        info.outputs.back() = static_cast<real_t>(1.0);
+      }
     }
   }
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index bee8bef37b44..37c32c09cebb 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -96,6 +96,13 @@ NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Context ctx,
                  dtype, aux_types, aux_shapes);
 }
 
+void NDArray::SetShapeFromChunk() {
+  if (Imperative::Get()->is_np_shape() ||
+      !(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
+    shape_ = ptr_->storage_shape;
+  }
+}
+
 struct ChunkMem {
   Storage::Handle h;
   std::vector<Storage::Handle> aux_h;
diff --git a/src/operator/contrib/boolean_mask.cc b/src/operator/contrib/boolean_mask.cc
index 4d66e1ec0a69..f431d77f26ce 100644
--- a/src/operator/contrib/boolean_mask.cc
+++ b/src/operator/contrib/boolean_mask.cc
@@ -143,6 +143,7 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
   // set the output shape forcefully
   mxnet::TShape s = data.shape();
   s[axis] = valid_num;
+  const_cast<NDArray &>(out).Init(s);
   // do the copy
   MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
diff --git a/src/operator/contrib/boolean_mask.cu b/src/operator/contrib/boolean_mask.cu
index 47335bfd6b79..c4a06d25d70a 100644
--- a/src/operator/contrib/boolean_mask.cu
+++ b/src/operator/contrib/boolean_mask.cu
@@ -79,7 +79,6 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
                                 Stream<gpu>::GetStream(s));
   CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
                        cudaMemcpyDeviceToHost));
-  CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
   // Set the output shape forcefully
   mxnet::TShape data_shape = data.shape();
   data_shape[axis] = valid_num;
@@ -88,8 +87,10 @@ inline void BooleanMaskForward(const nnvm::NodeAttrs& attrs,
   size_t col_size = input_size / idx.shape()[0];
   // Do the copy
   MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
-    mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
-      s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+    if (valid_num > 0) {
+      mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+        s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+    }
   });
 }
 
@@ -143,9 +144,11 @@ inline void BooleanMaskBackward(const nnvm::NodeAttrs& attrs,
   size_t col_size = input_size / idx_size;
   // Backward pass
   MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
-    mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
-      s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
-      prefix_sum, col_size);
+    if (input_size > 0) {
+      mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+        s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
+        prefix_sum, col_size);
+    }
   });
 }
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 8f1c2533c62c..72bf5864ff4b 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5589,6 +5589,21 @@ def test_boolean_mask():
     assert same(out.asnumpy(), expected)
     assert same(data.grad.asnumpy(), expected_grad)
 
+    # test 0-size output
+    mx.set_np_shape(True)
+    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+    index = mx.nd.array([0, 0, 0])
+    data.attach_grad()
+    with mx.autograd.record():
+        out = mx.nd.contrib.boolean_mask(data, index)
+    out.backward()
+    data.grad.wait_to_read()
+    expected = np.zeros((0, 3))
+    expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
+    assert same(out.asnumpy(), expected)
+    assert same(data.grad.asnumpy(), expected_grad)
+    mx.set_np_shape(False)
+
     # test gradient
     shape = (100, 30)
     a = mx.nd.random.randint(0, 100, shape=shape)
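
For reviewers, a minimal sketch of the behavior this patch enables, adapted from the new unit test above (assumes a build that includes this change; mx.set_np_shape(True) enables numpy shape semantics, which is what permits zero-size shapes such as (0, 3)):

    import mxnet as mx
    import numpy as np

    mx.set_np_shape(True)             # allow zero-size output shapes
    data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
    mask = mx.nd.array([0, 0, 0])     # all-false mask: every row filtered out
    data.attach_grad()
    with mx.autograd.record():
        out = mx.nd.contrib.boolean_mask(data, mask)
    out.backward()
    print(out.shape)                  # (0, 3): zero-size output
    print(data.grad.asnumpy())        # zeros((3, 3)): well-defined zero gradient
    mx.set_np_shape(False)            # restore legacy shape semantics

Before this change, the GPU implementation failed the CHECK(valid_num > 0) assertion whenever every mask entry was zero; the patch removes that check and instead skips the copy/backward kernels when there is nothing to copy.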