fix gradient of boolean mask (apache#15175)
wkcn authored and haohuw committed Jun 23, 2019
1 parent 773d2ab commit 094effc
Showing 4 changed files with 51 additions and 6 deletions.
9 changes: 8 additions & 1 deletion src/operator/contrib/boolean_mask-inl.h
@@ -71,6 +71,7 @@ struct BooleanMaskBackwardKernel {
   template<typename DType>
   static void MSHADOW_XINLINE Map(int i,
                                   DType* igrad,
+                                  const OpReqType req,
                                   const DType* ograd,
                                   const int32_t* idx,
                                   const size_t col_size) {
@@ -79,7 +80,13 @@ struct BooleanMaskBackwardKernel {
     int32_t prev = (row_id == 0) ? 0 : idx[row_id - 1];
     int32_t curr = idx[row_id];
     if (prev != curr) {
-      igrad[i] = ograd[prev * col_size + col_id];
+      if (req == kAddTo)
+        igrad[i] += ograd[prev * col_size + col_id];
+      else
+        igrad[i] = ograd[prev * col_size + col_id];
+    } else {
+      if (req == kWriteTo || req == kWriteInplace)
+        igrad[i] = 0;
     }
   }
 };
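The fixed kernel now honors the gradient write request (req): kept rows either accumulate into or overwrite igrad, and unmasked rows are zeroed only on plain writes, so kAddTo no longer clobbers previously accumulated gradients. For reference, a minimal NumPy sketch of the per-element semantics; the function name and the 'add'/'write'/'inplace' strings are illustrative stand-ins for MXNet's OpReqType values, not MXNet API:

import numpy as np

def boolean_mask_backward(igrad, ograd, prefix_sum, col_size, req):
    # Sketch of BooleanMaskBackwardKernel::Map. prefix_sum[r] counts the
    # mask hits among rows 0..r, so row r was kept iff it differs from
    # prefix_sum[r - 1]; that previous count is also the output row index.
    for i in range(igrad.size):
        row_id, col_id = divmod(i, col_size)
        prev = 0 if row_id == 0 else prefix_sum[row_id - 1]
        curr = prefix_sum[row_id]
        if prev != curr:              # row was kept by the mask
            if req == 'add':          # kAddTo: accumulate
                igrad.flat[i] += ograd.flat[prev * col_size + col_id]
            else:                     # kWriteTo / kWriteInplace: overwrite
                igrad.flat[i] = ograd.flat[prev * col_size + col_id]
        elif req in ('write', 'inplace'):
            igrad.flat[i] = 0         # zero unmasked rows only on writes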
20 changes: 16 additions & 4 deletions src/operator/contrib/boolean_mask.cc
@@ -90,10 +90,12 @@ struct BooleanMaskForwardCPUKernel {
   }
 };
 
-struct BooleanMaskBackwardCPUKernel {
+
+struct BooleanMaskBackwardCPUWriteKernel {
   template<typename DType>
   static void Map(int i,
                   DType* igrad,
+                  const OpReqType /*req*/,
                   const DType* ograd,
                   const int32_t* idx,
                   const size_t col_size) {
@@ -102,6 +104,8 @@ struct BooleanMaskBackwardCPUKernel {
     int32_t curr = idx[i];
     if (prev != curr) {
       std::memcpy(igrad + i * col_size, ograd + prev * col_size, col_size * sizeof(DType));
+    } else {
+      std::memset(igrad + i * col_size, 0, col_size * sizeof(DType));
     }
   }
 };
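The CPU write-only kernel works row-by-row instead: memcpy copies the matching ograd row into each kept row, and memset zero-fills rows the mask dropped. A rough NumPy equivalent (names are illustrative; igrad and ograd are assumed to be 2-D arrays of shape (num_rows, col_size) and (num_kept, col_size)):

import numpy as np

def boolean_mask_backward_write(igrad, ograd, prefix_sum):
    # Row-wise sketch of BooleanMaskBackwardCPUWriteKernel::Map:
    # copy the matching ograd row into kept rows, zero the rest.
    for i in range(igrad.shape[0]):
        prev = 0 if i == 0 else prefix_sum[i - 1]
        if prefix_sum[i] != prev:     # row i was kept
            igrad[i] = ograd[prev]    # analog of std::memcpy
        else:
            igrad[i] = 0              # analog of std::memset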
@@ -114,6 +118,7 @@ inline void BooleanMaskForward<cpu>(const nnvm::NodeAttrs& attrs,
                                     const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
+  CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
   const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
   const int axis = param.axis;
   const NDArray &data = inputs[0];
@@ -158,6 +163,7 @@ inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
                                      const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
   // inputs: {ograd, data, idx}
   // outputs: {igrad_data, igrad_idx}
   const NDArray& ograd = inputs[0];
@@ -175,9 +181,15 @@ inline void BooleanMaskBackward<cpu>(const nnvm::NodeAttrs& attrs,
       prefix_sum[i] += (idx_dptr[i]) ? 1 : 0;
     }
     mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
-    mxnet_op::Kernel<BooleanMaskBackwardCPUKernel, cpu>::Launch(
-      stream, idx_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
-      prefix_sum.data(), col_size);
+    if (req[0] == kAddTo) {
+      mxnet_op::Kernel<BooleanMaskBackwardKernel, cpu>::Launch(
+        stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
+        ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
+    } else {
+      mxnet_op::Kernel<BooleanMaskBackwardCPUWriteKernel, cpu>::Launch(
+        stream, idx_size, igrad_data.data().dptr<DType>(), req[0],
+        ograd.data().dptr<DType>(), prefix_sum.data(), col_size);
+    }
     });
   });
 }
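Putting it together, the CPU backward pass builds a prefix sum over the mask and then dispatches on the write request, using the accumulating kernel only for kAddTo. A hedged sketch reusing the two functions above, with 'add'/'null' again standing in for OpReqType values:

import numpy as np

def boolean_mask_backward_cpu(igrad, ograd, mask, req):
    # Sketch of BooleanMaskBackward<cpu>'s dispatch logic.
    if req == 'null':                 # kNullOp: no gradient requested
        return
    prefix_sum = np.cumsum(mask != 0).astype(np.int32)
    if req == 'add':                  # element-wise accumulating kernel
        boolean_mask_backward(igrad, ograd, prefix_sum, igrad.shape[1], req)
    else:                             # row-wise write kernel
        boolean_mask_backward_write(igrad, ograd, prefix_sum)

# Example: mask [1, 0, 1] and ograd [[1, 2], [3, 4]] scatter back to
# igrad rows [[1, 2], [0, 0], [3, 4]] under a plain write request.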
4 changes: 3 additions & 1 deletion src/operator/contrib/boolean_mask.cu
@@ -36,6 +36,7 @@ inline void BooleanMaskForward<gpu>(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1U);
+  CHECK(req[0] == kWriteTo || req[0] == kWriteInplace);
   const BooleanMaskParam& param = nnvm::get<BooleanMaskParam>(attrs.parsed);
   const int axis = param.axis;
   const NDArray &data = inputs[0];
@@ -101,6 +102,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(inputs.size(), 3U);
   CHECK_EQ(outputs.size(), 2U);
+  if (req[0] == kNullOp) return;
   // inputs: {ograd, data, idx}
   // outputs: {igrad_data, igrad_idx}
   const NDArray& ograd = inputs[0];
@@ -142,7 +144,7 @@ inline void BooleanMaskBackward<gpu>(const nnvm::NodeAttrs& attrs,
   // Backward pass
   MSHADOW_TYPE_SWITCH(igrad_data.dtype(), DType, {
     mxnet_op::Kernel<BooleanMaskBackwardKernel, gpu>::Launch(
-      s, input_size, igrad_data.data().dptr<DType>(), ograd.data().dptr<DType>(),
+      s, input_size, igrad_data.data().dptr<DType>(), req[0], ograd.data().dptr<DType>(),
       prefix_sum, col_size);
   });
 }
24 changes: 24 additions & 0 deletions tests/python/unittest/test_operator.py
@@ -5337,6 +5337,30 @@ def test_boolean_mask():
     assert same(out.asnumpy(), expected)
     assert same(data.grad.asnumpy(), expected_grad)
 
+    # test gradient
+    shape = (100, 30)
+    a = mx.nd.random.randint(0, 100, shape=shape)
+    a.attach_grad()
+    bi = mx.nd.random.randint(0, 100, shape=shape[0:1]) > 50
+    ci = mx.nd.random.randint(0, 100, shape=shape[0:1]) < 50
+    mx_grad = mx.nd.zeros_like(a)
+    mx.autograd.mark_variables([a], [mx_grad], grad_reqs='add')
+    T = 3
+    for _ in range(T):
+        with mx.autograd.record():
+            b = mx.nd.contrib.boolean_mask(a, bi)
+            c = mx.nd.contrib.boolean_mask(a, ci)
+            su = b.sum() + c.sum()
+        su.backward()
+    grad = (bi + ci).asnumpy().reshape((-1,) + (1,) * (len(shape)-1))
+    grad = np.tile(grad, (1,) + shape[1:])
+    # T times
+    grad *= T
+    assert_allclose(a.grad.asnumpy(), grad)
+    a_np = a.asnumpy()
+    assert same(b.asnumpy(), a_np[bi.asnumpy().astype('bool')])
+    assert same(c.asnumpy(), a_np[ci.asnumpy().astype('bool')])
+
 
 @with_seed()
 def test_div_sqrt_dim():
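The failure mode this test guards against is easy to reproduce directly. A minimal sketch, assuming an MXNet 1.x build with the contrib operator available (before this commit, the backward kernel ignored kAddTo and overwrote the gradient buffer, so accumulation under grad_req='add' came out wrong):

import mxnet as mx

x = mx.nd.array([[1., 2.], [3., 4.], [5., 6.]])
mask = mx.nd.array([1, 0, 1])
x.attach_grad(grad_req='add')         # accumulate gradients across backwards
for _ in range(2):
    with mx.autograd.record():
        y = mx.nd.contrib.boolean_mask(x, mask)
    y.backward()
print(x.grad)                         # rows 0 and 2 should be 2.0, row 1 zero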
