This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix contrib.boolean_mask for 0-size output #15731

Merged: 1 commit, merged on Aug 2, 2019
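Summary: on GPU, contrib.boolean_mask previously hit a hard CHECK ("boolean_mask behavior not defined when all masks are 0") whenever the mask selected nothing. With this change the operator returns a 0-size output, and backward produces an all-zero gradient, under NumPy-style shape semantics. A minimal sketch of the fixed behavior, mirroring the unit test added below (assumes an MXNet build that includes this patch):

import mxnet as mx

mx.set_np_shape(True)  # 0-size shapes are only legal in np-shape mode
data = mx.nd.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = mx.nd.array([0, 0, 0])  # selects no rows
out = mx.nd.contrib.boolean_mask(data, mask)
print(out.shape)  # (0, 3) after this fix
mx.set_np_shape(False)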
6 changes: 1 addition & 5 deletions include/mxnet/ndarray.h
@@ -190,11 +190,7 @@ class NDArray {
  /*!
   * \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
   */
-  void SetShapeFromChunk() {
-    if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
-      shape_ = ptr_->storage_shape;
-    }
-  }
+  void SetShapeFromChunk();
  /*
   * This indicates whether an array is a view of another array (created by
   * reshape or slice). If an array is a view and the data is stored in
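The inline definition moves out of the header: the fixed version (in ndarray.cc below) additionally consults Imperative::Get()->is_np_shape(), a dependency the public ndarray.h header should not pull in. A rough Python model of the old vs. new predicate (names here are illustrative, not MXNet API):

def should_adopt_storage_shape(storage_shape, is_np_shape):
    # a (0,)-shaped chunk presumably doubles as an "uninitialized" sentinel in legacy mode
    is_zero_sentinel = len(storage_shape) == 1 and storage_shape[0] == 0
    # old behavior: return not is_zero_sentinel
    # new behavior: in np-shape mode, (0,) is a real, valid shape, so adopt it too
    return is_np_shape or not is_zero_sentinel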
4 changes: 3 additions & 1 deletion src/imperative/imperative.cc
@@ -313,7 +313,9 @@ std::vector<NDArray*> Imperative::Backward(
    } else {
      info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
                                true, outputs[i]->dtype());
-      info.outputs.back() = static_cast<real_t>(1.0);
+      if (info.outputs.back().shape().Size() != 0) {
+        info.outputs.back() = static_cast<real_t>(1.0);
+      }
    }
  }

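Backward() materializes the default head gradient as an array of ones; assigning that scalar to a 0-size NDArray is now skipped, since there is nothing to fill. A NumPy analogue of the new guard:

import numpy as np

head_grad = np.empty((0, 3))   # 0-size head gradient for a 0-size output
if head_grad.size != 0:        # mirrors shape().Size() != 0
    head_grad[:] = 1.0         # skipped here: an empty array has no elements to set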
7 changes: 7 additions & 0 deletions src/ndarray/ndarray.cc
@@ -96,6 +96,13 @@ NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Con
                 dtype, aux_types, aux_shapes);
}

+void NDArray::SetShapeFromChunk() {
+  if (Imperative::Get()->is_np_shape() ||
+      !(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
+    shape_ = ptr_->storage_shape;
+  }
+}
+
struct ChunkMem {
  Storage::Handle h;
  std::vector<Storage::Handle> aux_h;
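SetShapeFromChunk now keys off the global np-shape flag. From Python the flag is flipped with mx.set_np_shape, exactly as the new test does; a short usage sketch (assuming set_np_shape returns the previous state, as in MXNet 1.x):

import mxnet as mx

prev = mx.set_np_shape(True)   # enable NumPy-style shape semantics
assert mx.is_np_shape()
empty = mx.nd.zeros((0, 3))    # a genuine 0-size array, legal in this mode
mx.set_np_shape(prev)          # restore the previous mode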
1 change: 1 addition & 0 deletions src/operator/contrib/boolean_mask.cc
@@ -143,6 +143,7 @@ inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
  // set the output shape forcefully
  mxnet::TShape s = data.shape();
  s[axis] = valid_num;
+
  const_cast<NDArray &>(out).Init(s);
  // do the copy
  MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
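On CPU the forward already sizes the output from valid_num, so with valid_num == 0 it now simply yields a 0-size output. A NumPy sketch of the shape computation (axis 0, the operator's default):

import numpy as np

data = np.arange(9).reshape(3, 3)
mask = np.array([0, 0, 0])
valid_num = int(np.count_nonzero(mask))   # 0 here
out_shape = list(data.shape)
out_shape[0] = valid_num                  # s[axis] = valid_num
out = data[mask.astype(bool)]
assert out.shape == tuple(out_shape)      # (0, 3)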
15 changes: 9 additions & 6 deletions src/operator/contrib/boolean_mask.cu
@@ -79,7 +79,6 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
                                    Stream<gpu>::GetStream(s));
  CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
                       cudaMemcpyDeviceToHost));
-  CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
  // Set the output shape forcefully
  mxnet::TShape data_shape = data.shape();
  data_shape[axis] = valid_num;
@@ -88,8 +87,10 @@
  size_t col_size = input_size / idx.shape()[0];
  // Do the copy
  MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
-    mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
-      s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+    if (valid_num > 0) {
+      mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
+        s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
+    }
  });
}
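The GPU forward builds an inclusive prefix sum over the mask so every selected element knows its destination index, and now launches the copy kernel only when valid_num > 0 (with a 0-size output there is nothing to copy into). A NumPy model of the indexing scheme, with one loop iteration standing in for one GPU thread (illustrative only):

import numpy as np

data = np.arange(12).reshape(4, 3)
mask = np.array([0, 1, 0, 1])
prefix_sum = np.cumsum(mask)       # inclusive scan over the mask
valid_num = int(prefix_sum[-1])    # == prefix_sum[idx_size - 1]
out = np.empty((valid_num, 3), dtype=data.dtype)
if valid_num > 0:                  # mirrors the new guard
    for i in range(mask.size):     # one "thread" per input row
        if mask[i]:
            out[prefix_sum[i] - 1] = data[i]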

@@ -143,9 +144,11 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
  size_t col_size = input_size / idx_size;
  // Backward pass
  MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
-    mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
-      s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
-      prefix_sum, col_size);
+    if (input_size > 0) {
+      mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
+        s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
+        prefix_sum, col_size);
+    }
  });
});
}

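The backward guard is the same idea: when the forward produced a 0-size output, the incoming gradient is empty and the scatter kernel is skipped, leaving igrad all zeros, which is what the new test asserts. A compact NumPy analogue (assuming input_size counts the elements of the incoming gradient here):

import numpy as np

igrad = np.zeros((3, 3))     # gradient w.r.t. data
ograd = np.zeros((0, 3))     # gradient of the 0-size output
input_size = ograd.size      # 0: nothing to scatter back
if input_size > 0:           # mirrors the new guard
    pass                     # scatter ograd rows back through the mask positions
# igrad stays all zeros, matching expected_grad in the test below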
15 changes: 15 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -5589,6 +5589,21 @@ def test_boolean_mask():
    assert same(out.asnumpy(), expected)
    assert same(data.grad.asnumpy(), expected_grad)

+    # test 0-size output
+    mx.set_np_shape(True)
+    data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
+    index = mx.nd.array([0, 0, 0])
+    data.attach_grad()
+    with mx.autograd.record():
+        out = mx.nd.contrib.boolean_mask(data, index)
+    out.backward()
+    data.grad.wait_to_read()
+    expected = np.zeros((0, 3))
+    expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
+    assert same(out.asnumpy(), expected)
+    assert same(data.grad.asnumpy(), expected_grad)
+    mx.set_np_shape(False)
+
# test gradient
shape = (100, 30)
a = mx.nd.random.randint(0, 100, shape=shape)
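For reference, the test's expected values match plain NumPy boolean indexing (a quick sanity check, not part of the PR):

import numpy as np

data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
index = np.array([0, 0, 0], dtype=bool)
out = data[index]                   # shape (0, 3), same as the operator's output
assert out.shape == (0, 3)
expected_grad = np.zeros_like(data) # no row selected, so the gradient is all zeros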