Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix boolean_mask for 0-size output (#15731)
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 authored and reminisce committed Aug 2, 2019
1 parent cf28b46 commit 87425d2
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 12 deletions.
6 changes: 1 addition & 5 deletions include/mxnet/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,7 @@ class NDArray {
/*!
* \brief set the correct shape of NDArray directly from the storage_shape of its own chunk.
*/
void SetShapeFromChunk() {
if (!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
shape_ = ptr_->storage_shape;
}
}
void SetShapeFromChunk();
/*
* This indicates whether an array is a view of another array (created by
* reshape or slice). If an array is a view and the data is stored in
Expand Down
4 changes: 3 additions & 1 deletion src/imperative/imperative.cc
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ std::vector<NDArray*> Imperative::Backward(
} else {
info.outputs.emplace_back(outputs[i]->shape(), outputs[i]->ctx(),
true, outputs[i]->dtype());
info.outputs.back() = static_cast<real_t>(1.0);
if (info.outputs.back().shape().Size() != 0) {
info.outputs.back() = static_cast<real_t>(1.0);
}
}
}

Expand Down
7 changes: 7 additions & 0 deletions src/ndarray/ndarray.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ NDArray::NDArray(const NDArrayStorageType stype, const mxnet::TShape &shape, Con
dtype, aux_types, aux_shapes);
}

void NDArray::SetShapeFromChunk() {
if (Imperative::Get()->is_np_shape() ||
!(ptr_->storage_shape.ndim() == 1 && ptr_->storage_shape[0] == 0)) {
shape_ = ptr_->storage_shape;
}
}

struct ChunkMem {
Storage::Handle h;
std::vector<Storage::Handle> aux_h;
Expand Down
1 change: 1 addition & 0 deletions src/operator/contrib/boolean_mask.cc
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
// set the output shape forcefully
mxnet::TShape s = data.shape();
s[axis] = valid_num;

const_cast<NDArray &>(out).Init(s);
// do the copy
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
Expand Down
15 changes: 9 additions & 6 deletions src/operator/contrib/boolean_mask.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
Stream<gpu>::GetStream(s));
CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[idx_size - 1], sizeof(int32_t),
cudaMemcpyDeviceToHost));
CHECK(valid_num > 0) << "boolean_mask behavior not defined when all masks are 0";
// Set the output shape forcefully
mxnet::TShape data_shape = data.shape();
data_shape[axis] = valid_num;
Expand All @@ -88,8 +87,10 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
size_t col_size = input_size / idx.shape()[0];
// Do the copy
MSHADOW_TYPE_SWITCH(out.dtype(), DType, {
mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
if (valid_num > 0) {
mxnet_op::Kernel<BooleanMaskForwardKernel, gpu>::Launch(
s, input_size, out.data().dptr<DType>(), data.data().dptr<DType>(), prefix_sum, col_size);
}
});
}

Expand Down Expand Up @@ -143,9 +144,11 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
size_t col_size = input_size / idx_size;
// Backward pass
MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
prefix_sum, col_size);
if (input_size > 0) {
mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
prefix_sum, col_size);
}
});
}

Expand Down
15 changes: 15 additions & 0 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -5589,6 +5589,21 @@ def test_boolean_mask():
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)

# test 0-size output
mx.set_np_shape(True)
data = mx.nd.array([[1, 2, 3],[4, 5, 6],[7, 8, 9]])
index = mx.nd.array([0, 0, 0])
data.attach_grad()
with mx.autograd.record():
out = mx.nd.contrib.boolean_mask(data, index)
out.backward()
data.grad.wait_to_read()
expected = np.zeros((0, 3))
expected_grad = np.array([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
assert same(out.asnumpy(), expected)
assert same(data.grad.asnumpy(), expected_grad)
mx.set_np_shape(False)

# test gradient
shape = (100, 30)
a = mx.nd.random.randint(0, 100, shape=shape)
Expand Down

0 comments on commit 87425d2

Please sign in to comment.