
[MXNET-1414] fix gradient bug of boolean mask #15175

Merged · 1 commit · Jun 9, 2019
src/operator/contrib/boolean_mask-inl.h (9 changes: 8 additions & 1 deletion)
@@ -71,6 +71,7 @@ struct BooleanMaskBackwardKernel {
   template<typename DType>
   static void MSHADOW_XINLINE Map(int i,
                                   DType* igrad,
+                                  const OpReqType req,
                                   const DType* ograd,
                                   const int32_t* idx,
                                   const size_t col_size) {
@@ -79,7 +80,13 @@ struct BooleanMaskBackwardKernel {
     int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
     int32_t curr = idx[row_id];
     if (prev != curr) {
-      igrad[i] = ograd[prev * col_size + col_id];
+      if (req == kAddTo)
+        igrad[i] += ograd[prev * col_size + col_id];
+      else
+        igrad[i] = ograd[prev * col_size + col_id];
+    } else {
+      if (req == kWriteTo || req == kWriteInplace)
+        igrad[i] = 0;
     }
   }
 };
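Note on the kernel change above: BooleanMaskBackwardKernel now receives the OpReqType, accumulates into igrad for kAddTo instead of unconditionally overwriting, and zeroes the rows the mask skipped when the request is kWriteTo or kWriteInplace. From the Python side, the behavior this fixes can be sketched as follows (a minimal example assuming a build that contains this patch; shapes and values are illustrative):

import mxnet as mx

a = mx.nd.arange(12).reshape((4, 3))
a.attach_grad(grad_req='add')        # accumulate gradients across backward calls
mask = mx.nd.array([1, 0, 1, 0])
for _ in range(2):
    with mx.autograd.record():
        y = mx.nd.contrib.boolean_mask(a, mask)
    y.backward()
print(a.grad)  # rows 0 and 2 should hold 2.0 after two passes; rows 1 and 3 stay 0

Before the fix, the second pass overwrote the first, leaving 1.0 in the selected rows.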
src/operator/contrib/boolean_mask.cc (20 changes: 16 additions & 4 deletions)
@@ -90,10 +90,12 @@ struct BooleanMaskForwardCPUKernel {
   }
 };
 
-struct BooleanMaskBackwardCPUKernel {
+
+struct BooleanMaskBackwardCPUWriteKernel {
   template<typename DType>
   static void Map(int i,
                   DType* igrad,
+                  const OpReqType /*req*/,
                   const DType* ograd,
                   const int32_t* idx,
                   const size_t col_size) {
@@ -102,6 +104,8 @@ struct BooleanMaskBackwardCPUKernel {
     int32_t curr = idx[i];
     if (prev != curr) {
       std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
+    } else {
+      std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
     }
   }
 };
@@ -114,6 +118,7 @@ inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
+  CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
   const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
   const int axis = param.axis;
   const NDArray &data = inputs[0];
@@ -158,6 +163,7 @@ inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
                                      const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
   // inputs: {ograd, data, idx}
   // outputs: {igrad_data, igrad_idx}
   const NDArray& ograd = inputs[0];
@@ -175,9 +181,15 @@ inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
     prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
   }
   mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
-  mxnet_op::Kernel<BooleanMaskBackwardCPUKernel, cpu>::Launch(
-    stream, idx_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
-    prefix_sum.data(), col_size);
+  if (req[0] == kAddTo) {
+    mxnet_op::Kernel<BooleanMaskBackwardKernel, cpu>::Launch(
+      stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
+      ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
+  } else {
+    mxnet_op::Kernel<BooleanMaskBackwardCPUWriteKernel, cpu>::Launch(
+      stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
+      ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
+  }
   });
   });
 }
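With this dispatch, kAddTo routes to the shared element-wise BooleanMaskBackwardKernel, while plain write requests keep the faster row-wise memcpy/memset kernel. A NumPy restatement of the backward scatter semantics, for reference (a hypothetical helper illustrating the logic, not the MXNet source):

import numpy as np

def boolean_mask_backward_ref(ograd, idx, igrad, req='write'):
    # prefix[i] counts mask hits up to and including row i (the prefix_sum above)
    prefix = np.cumsum(idx.astype(np.int64))
    for i in range(len(idx)):
        prev = 0 if i == 0 else prefix[i - 1]
        if prev != prefix[i]:            # row i was selected by the mask
            if req == 'add':
                igrad[i] += ograd[prev]  # kAddTo accumulates
            else:
                igrad[i] = ograd[prev]   # kWriteTo/kWriteInplace overwrite
        elif req != 'add':
            igrad[i] = 0                 # unselected rows are zeroed on write
    return igrad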
src/operator/contrib/boolean_mask.cu (4 changes: 3 additions & 1 deletion)
@@ -36,6 +36,7 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
+  CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
   const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
   const int axis = param.axis;
   const NDArray &data = inputs[0];
@@ -101,6 +102,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
   // inputs: {ograd, data, idx}
   // outputs: {igrad_data, igrad_idx}
   const NDArray& ograd = inputs[0];
@@ -142,7 +144,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
   // Backward pass
   MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
     mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
-      s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+      s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
       prefix_sum, col_size);
   });
 }
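The GPU path already used the shared BooleanMaskBackwardKernel, so it only needed the new req[0] argument plus the kNullOp early return. One way to sanity-check that the CPU and GPU backends now agree (a sketch; assumes a GPU build and that device 0 exists):

import mxnet as mx

def masked_grad(ctx):
    a = mx.nd.ones((5, 2), ctx=ctx)
    a.attach_grad(grad_req='add')
    mask = mx.nd.array([1, 1, 0, 0, 1], ctx=ctx)
    with mx.autograd.record():
        mx.nd.contrib.boolean_mask(a, mask).sum().backward()
    return a.grad.asnumpy()

# the two backends should produce identical gradients after this patch:
# assert (masked_grad(mx.cpu()) == masked_grad(mx.gpu(0))).all()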
tests/python/unittest/test_operator.py (24 changes: 24 additions & 0 deletions)
@@ -5339,6 +5339,30 @@ def test_boolean_mask():
     assert same(out.asnumpy(), expected)
     assert same(data.grad.asnumpy(), expected_grad)
 
+    # test gradient accumulation with grad_req='add'
+    shape = (100, 30)
+    a = mx.nd.random.randint(0, 100, shape=shape)
+    a.attach_grad()
+    bi = mx.nd.random.randint(0, 100, shape=shape[0:1]) > 50
+    ci = mx.nd.random.randint(0, 100, shape=shape[0:1]) < 50
+    mx_grad = mx.nd.zeros_like(a)
+    mx.autograd.mark_variables([a], [mx_grad], grad_reqs='add')
+    T = 3
+    for _ in range(T):
+        with mx.autograd.record():
+            b = mx.nd.contrib.boolean_mask(a, bi)
+            c = mx.nd.contrib.boolean_mask(a, ci)
+            su = b.sum() + c.sum()
+        su.backward()
+    grad = (bi + ci).asnumpy().reshape((-1,) + (1,) * (len(shape)-1))
+    grad = np.tile(grad, (1,) + shape[1:])
+    # each of the T passes contributes one copy of the gradient
+    grad *= T
+    assert_allclose(a.grad.asnumpy(), grad)
+    a_np = a.asnumpy()
+    assert same(b.asnumpy(), a_np[bi.asnumpy().astype('bool')])
+    assert same(c.asnumpy(), a_np[ci.asnumpy().astype('bool')])
+
 
 @with_seed()
 def test_div_sqrt_dim():
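About the expected gradient in the test above: each recorded pass contributes d(su)/d a[r, c] = bi[r] + ci[r], since a row selected by both masks is counted twice, and grad_reqs='add' accumulates the T passes, which is why the test multiplies by T. An equivalent one-liner for the expected value (illustrative; relies on numpy imported as np, as the test file does):

expected = T * np.tile((bi + ci).asnumpy()[:, None], (1, shape[1]))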