From 82c34ee51052d07d967ed99f573beb22779456cf Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Fri, 6 Dec 2019 22:13:46 -0800
Subject: [PATCH] boolean_mask_assign with start_axis (#16886)

---
 src/common/utils.h                           | 17 +++--
 src/operator/numpy/np_boolean_mask_assign.cc | 75 ++++++++++++++++----
 src/operator/numpy/np_boolean_mask_assign.cu | 51 +++++++++----
 tests/python/unittest/test_numpy_op.py       | 58 ++++++++++++---
 4 files changed, 159 insertions(+), 42 deletions(-)

diff --git a/src/common/utils.h b/src/common/utils.h
index 0e3e35430652..9a9c686e73c9 100644
--- a/src/common/utils.h
+++ b/src/common/utils.h
@@ -435,6 +435,15 @@ inline std::string dev_type_string(const int dev_type) {
   return "unknown";
 }
 
+inline std::string attr_value_string(const nnvm::NodeAttrs& attrs,
+                                     const std::string& attr_name,
+                                     std::string default_val = "") {
+  if (attrs.dict.find(attr_name) == attrs.dict.end()) {
+    return default_val;
+  }
+  return attrs.dict.at(attr_name);
+}
+
 /*! \brief get string representation of the operator stypes */
 inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
@@ -463,10 +472,10 @@ inline std::string operator_stype_string(const nnvm::NodeAttrs& attrs,
 
 /*! \brief get string representation of the operator */
 inline std::string operator_string(const nnvm::NodeAttrs& attrs,
-                                    const OpContext& ctx,
-                                    const std::vector<NDArray>& inputs,
-                                    const std::vector<OpReqType>& req,
-                                    const std::vector<NDArray>& outputs) {
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
   std::string result = "";
   std::vector<int> in_stypes;
   std::vector<int> out_stypes;
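
The new helper above is just a keyed lookup with a fallback; a rough Python
sketch of its behavior (the function name mirrors the C++ helper, the sample
dict is illustrative only, not part of the patch):

    def attr_value_string(attr_dict, attr_name, default_val=""):
        # operator attributes travel as a string-to-string dict;
        # missing keys fall back to the supplied default
        return attr_dict.get(attr_name, default_val)

    # how the operators below parse start_axis, with "0" as the default
    start_axis = int(attr_value_string({"value": "42.0"}, "start_axis", "0"))  # -> 0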
diff --git a/src/operator/numpy/np_boolean_mask_assign.cc b/src/operator/numpy/np_boolean_mask_assign.cc
index 2a5ae116e291..0df3609ecab1 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cc
+++ b/src/operator/numpy/np_boolean_mask_assign.cc
@@ -22,6 +22,7 @@
  * \brief CPU implementation of Boolean Mask Assign
  */
 
+#include "../../common/utils.h"
 #include "../contrib/boolean_mask-inl.h"
 
 namespace mxnet {
@@ -88,6 +89,7 @@ struct BooleanAssignCPUKernel {
                   const size_t idx_size,
                   const size_t leading,
                   const size_t middle,
+                  const size_t valid_num,
                   const size_t trailing,
                   DType* tensor) {
     // binary search for the turning point
@@ -95,7 +97,8 @@ struct BooleanAssignCPUKernel {
     // final answer is in mid
     for (size_t l = 0; l < leading; ++l) {
       for (size_t t = 0; t < trailing; ++t) {
-        data[(l * middle + mid) * trailing + t] = (scalar) ? tensor[0] : tensor[i];
+        data[(l * middle + mid) * trailing + t] =
+            (scalar) ? tensor[0] : tensor[(l * valid_num + i) * trailing + t];
       }
     }
   }
@@ -106,19 +109,47 @@ bool BooleanAssignShape(const nnvm::NodeAttrs& attrs,
                         mxnet::ShapeVector *out_attrs) {
   CHECK(in_attrs->size() == 2U || in_attrs->size() == 3U);
   CHECK_EQ(out_attrs->size(), 1U);
+  CHECK(shape_is_known(in_attrs->at(0)) && shape_is_known(in_attrs->at(1)))
+    << "shape of both input and mask should be known";
   const TShape& dshape = in_attrs->at(0);
+  const TShape& mshape = in_attrs->at(1);
+  const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));
 
-  // mask should have the same shape as the input
-  SHAPE_ASSIGN_CHECK(*in_attrs, 1, dshape);
+  for (int i = 0; i < mshape.ndim(); ++i) {
+    CHECK_EQ(dshape[i + start_axis], mshape[i])
+      << "boolean index did not match indexed array along dimension " << i + start_axis
+      << "; dimension is " << dshape[i + start_axis]
+      << " but corresponding boolean dimension is " << mshape[i];
+  }
 
   // check if output shape is the same as the input data
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
 
   // for tensor version, the tensor should have less than 1 dimension
   if (in_attrs->size() == 3U) {
-    CHECK_LE(in_attrs->at(2).ndim(), 1U)
-      << "boolean array indexing assignment requires a 0 or 1-dimensional input, input has "
-      << in_attrs->at(2).ndim() <<" dimensions";
+    if (mshape.ndim() == dshape.ndim()) {
+      CHECK_LE(in_attrs->at(2).ndim(), 1U)
+        << "boolean array indexing assignment requires a 0 or 1-dimensional input, input has "
+        << in_attrs->at(2).ndim() << " dimensions";
+    } else {
+      const TShape& vshape = in_attrs->at(2);
+      if (vshape.Size() > 1) {
+        for (int i = 0; i < dshape.ndim(); ++i) {
+          if (i < start_axis) {
+            CHECK_EQ(dshape[i], vshape[i])
+              << "shape mismatch of value with input at dimension " << i
+              << "; dimension is " << dshape[i]
+              << " but corresponding value dimension is " << vshape[i];
+          }
+          if (i >= start_axis + mshape.ndim()) {
+            CHECK_EQ(dshape[i], vshape[i - mshape.ndim() + 1])
+              << "shape mismatch of value with input at dimension " << i
+              << "; dimension is " << dshape[i]
+              << " but corresponding value dimension is " << vshape[i - mshape.ndim() + 1];
+          }
+        }
+      }
+    }
   }
 
   return shape_is_known(out_attrs->at(0));
 }
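
The value-tensor shape accepted by BooleanAssignShape above can be summarized
as: the data dims before start_axis, then the number of True entries at
start_axis, then the data dims after the mask. A small Python sketch of that
rule, mirroring the vshape loop the new test builds (the function name is ours,
not part of the patch):

    def expected_value_shape(dshape, mshape, start_axis, valid_num):
        # leading data dims, then valid_num, then trailing data dims
        vshape = []
        for i in range(len(dshape)):
            if i < start_axis:
                vshape.append(dshape[i])
            elif i == start_axis:
                vshape.append(valid_num)
            elif i >= start_axis + len(mshape):
                vshape.append(dshape[i])
        return tuple(vshape)

    # data (2, 3, 4, 5), mask (3, 4) at start_axis=1 with 6 True entries -> (2, 6, 5)
    assert expected_value_shape((2, 3, 4, 5), (3, 4), 1, 6) == (2, 6, 5)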
@@ -170,22 +201,26 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   Stream<cpu> *s = ctx.get_stream<cpu>();
   const TBlob& data = inputs[0];
+  const TShape& dshape = data.shape_;
   const TBlob& mask = inputs[1];
+  const TShape& mshape = mask.shape_;
+  const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));
   // Get valid_num
   size_t valid_num = 0;
   size_t mask_size = mask.shape_.Size();
   std::vector<size_t> prefix_sum(mask_size + 1, 0);
-  MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, {
+  MSHADOW_TYPE_SWITCH_WITH_BOOL(mask.type_flag_, MType, {
     valid_num = GetValidNumCPU(mask.dptr<MType>(), prefix_sum.data(), mask_size);
   });
 
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
   if (inputs.size() == 3U) {
+    const TShape& vshape = inputs[2].shape_;
     if (inputs[2].shape_.Size() != 1) {
       // tensor case, check tensor size with the valid_num
-      CHECK_EQ(static_cast<size_t>(valid_num), inputs[2].shape_.Size())
-        << "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size()
+      CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
+        << "boolean array indexing assignment cannot assign " << vshape
         << " input values to the " << valid_num
         << " output values where the mask is true" << std::endl;
     }
   }
@@ -195,21 +230,29 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
   size_t leading = 1U;
   size_t middle = mask_size;
   size_t trailing = 1U;
 
+  for (int i = 0; i < dshape.ndim(); ++i) {
+    if (i < start_axis) {
+      leading *= dshape[i];
+    }
+    if (i >= start_axis + mshape.ndim()) {
+      trailing *= dshape[i];
+    }
+  }
+
   if (inputs.size() == 3U) {
     MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
       if (inputs[2].shape_.Size() == 1) {
         Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
           s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
-          leading, middle, trailing, inputs[2].dptr<DType>());
+          leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       } else {
         Kernel<BooleanAssignCPUKernel<false>, cpu>::Launch(
           s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
-          leading, middle, trailing, inputs[2].dptr<DType>());
+          leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       }
     });
   } else {
-    CHECK(attrs.dict.find("value") != attrs.dict.end())
-      << "value needs be provided";
+    CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs to be provided";
     MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
       Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
         s, valid_num, data.dptr<DType>(), prefix_sum.data(), prefix_sum.size(),
         leading, middle, trailing, static_cast<DType>(std::stod(attrs.dict.at("value"))));
@@ -240,7 +283,8 @@ NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("mask", "NDArray-or-Symbol", "mask")
-.add_argument("value", "float", "value to be assigned to masked positions");
+.add_argument("value", "float", "value to be assigned to masked positions")
+.add_argument("start_axis", "int", "starting axis of boolean mask");
 
 NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor)
 .describe(R"code(Tensor version of boolean assign)code" ADD_FILELINE)
@@ -264,7 +308,8 @@ NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor)
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "input")
 .add_argument("mask", "NDArray-or-Symbol", "mask")
-.add_argument("value", "NDArray-or-Symbol", "assignment");
+.add_argument("value", "NDArray-or-Symbol", "assignment")
+.add_argument("start_axis", "int", "starting axis of boolean mask");
 
 } // namespace op
 } // namespace mxnet
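
End to end, the CPU path above now: (1) folds the axes before start_axis into
leading and the axes after the mask into trailing, (2) builds a prefix sum over
the flattened mask, and (3) for the i-th True entry binary-searches its flat
position mid and copies one (leading, trailing) slab. A plain NumPy sketch of
that logic under the same decomposition (an illustration, not the actual
mshadow kernel):

    import numpy as np

    def boolean_assign_sketch(data, mask, value, start_axis):
        leading = int(np.prod(data.shape[:start_axis], dtype=np.int64))
        middle = mask.size
        trailing = int(np.prod(data.shape[start_axis + mask.ndim:], dtype=np.int64))
        # prefix_sum[k] = number of True entries in mask.ravel()[:k]
        prefix_sum = np.concatenate([[0], np.cumsum(mask.ravel() != 0)])
        valid_num = int(prefix_sum[-1])
        d = data.reshape(leading, middle, trailing)
        v = value.reshape(leading, valid_num, trailing)
        for i in range(valid_num):
            # stand-in for bin_search: flat position of the (i+1)-th True entry
            mid = int(np.searchsorted(prefix_sum, i + 1)) - 1
            d[:, mid, :] = v[:, i, :]
        return d.reshape(data.shape)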
diff --git a/src/operator/numpy/np_boolean_mask_assign.cu b/src/operator/numpy/np_boolean_mask_assign.cu
index 2ccc4ffcd98d..5f3c24c02f33 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cu
+++ b/src/operator/numpy/np_boolean_mask_assign.cu
@@ -23,6 +23,7 @@
  */
 
 #include <cub/cub.cuh>
+#include "../../common/utils.h"
 #include "../contrib/boolean_mask-inl.h"
 
 namespace mxnet {
@@ -70,13 +71,17 @@ struct BooleanAssignGPUKernel {
                                   const size_t idx_size,
                                   const size_t leading,
                                   const size_t middle,
+                                  const size_t valid_num,
                                   const size_t trailing,
                                   const DType val) {
     // binary search for the turning point
-    size_t m = i / trailing % middle;
+    size_t m = i / trailing % valid_num;
+    size_t l = i / trailing / valid_num;
     size_t mid = bin_search(idx, idx_size, m);
     // final answer is in mid
-    data[i + (mid - m) * trailing] = val;
+    // i = l * valid_num * trailing + m * trailing + t
+    // dst = l * middle * trailing + mid * trailing + t
+    data[i + (l * (middle - valid_num) + (mid - m)) * trailing] = val;
   }
 
   template <typename DType>
@@ -86,13 +91,20 @@ struct BooleanAssignGPUKernel {
                                   const size_t idx_size,
                                   const size_t leading,
                                   const size_t middle,
+                                  const size_t valid_num,
                                   const size_t trailing,
                                   DType* tensor) {
     // binary search for the turning point
-    size_t m = i / trailing % middle;
+    size_t m = i / trailing % valid_num;
+    size_t l = i / trailing / valid_num;
     size_t mid = bin_search(idx, idx_size, m);
+    size_t dst = i + (l * (middle - valid_num) + (mid - m)) * trailing;
     // final answer is in mid
-    data[i + (mid - m) * trailing] = (scalar) ? tensor[0] : tensor[m];
+    if (scalar) {
+      data[dst] = tensor[0];
+    } else {
+      data[dst] = tensor[i];
+    }
   }
 };
@@ -166,13 +178,17 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
   Stream<gpu> *s = ctx.get_stream<gpu>();
   const TBlob& data = inputs[0];
+  const TShape& dshape = data.shape_;
   const TBlob& mask = inputs[1];
+  const TShape& mshape = mask.shape_;
+  const int start_axis = std::stoi(common::attr_value_string(attrs, "start_axis", "0"));
+
   // Get valid_num
   size_t mask_size = mask.shape_.Size();
   size_t valid_num = 0;
   size_t* prefix_sum = nullptr;
   if (mask_size != 0) {
-    MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, {
+    MSHADOW_TYPE_SWITCH_WITH_BOOL(mask.type_flag_, MType, {
       prefix_sum = GetValidNumGPU(ctx, mask.dptr<MType>(), mask_size);
     });
     cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
@@ -180,14 +196,16 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
                               cudaMemcpyDeviceToHost, stream));
     CUDA_CALL(cudaStreamSynchronize(stream));
   }
+
   // If there's no True in mask, return directly
   if (valid_num == 0) return;
 
   if (inputs.size() == 3U) {
+    const TShape& vshape = inputs[2].shape_;
     if (inputs[2].shape_.Size() != 1) {
       // tensor case, check tensor size with the valid_num
-      CHECK_EQ(static_cast<size_t>(valid_num), inputs[2].shape_.Size())
-        << "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size()
+      CHECK_EQ(static_cast<size_t>(valid_num), vshape[start_axis])
+        << "boolean array indexing assignment cannot assign " << vshape
        << " input values to the " << valid_num
         << " output values where the mask is true" << std::endl;
     }
   }
@@ -197,27 +215,36 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
   size_t middle = mask_size;
   size_t trailing = 1U;
 
+  for (int i = 0; i < dshape.ndim(); ++i) {
+    if (i < start_axis) {
+      leading *= dshape[i];
+    }
+    if (i >= start_axis + mshape.ndim()) {
+      trailing *= dshape[i];
+    }
+  }
+
   if (inputs.size() == 3U) {
     if (inputs[2].shape_.Size() == 1) {
       MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
         Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
           s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
-          leading, middle, trailing, inputs[2].dptr<DType>());
+          leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       });
     } else {
       MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
         Kernel<BooleanAssignGPUKernel<false>, gpu>::Launch(
           s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
-          leading, middle, trailing, inputs[2].dptr<DType>());
+          leading, middle, valid_num, trailing, inputs[2].dptr<DType>());
       });
     }
   } else {
-    CHECK(attrs.dict.find("value") != attrs.dict.end())
-      << "value is not provided";
+    CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value is not provided";
+    double value = std::stod(attrs.dict.at("value"));
     MSHADOW_TYPE_SWITCH(data.type_flag_, DType, {
       Kernel<BooleanAssignGPUKernel<true>, gpu>::Launch(
         s, leading * valid_num * trailing, data.dptr<DType>(), prefix_sum, mask_size + 1,
-        leading, middle, trailing, static_cast<DType>(std::stod(attrs.dict.at("value"))));
+        leading, middle, valid_num, trailing, static_cast<DType>(value));
     });
   }
 }
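
The GPU kernels instead launch one thread per element of the compact
(leading, valid_num, trailing) value layout and have each thread derive its
destination offset in the full (leading, middle, trailing) data layout. The
index arithmetic in the kernel comments can be checked host-side (a Python
stand-in for the device code, not part of the patch):

    import numpy as np

    def gpu_dst_index(i, prefix_sum, valid_num, middle, trailing):
        m = (i // trailing) % valid_num    # which True entry this thread writes
        l = (i // trailing) // valid_num   # leading coordinate
        # bin_search over the prefix sum: flat mask position of the (m+1)-th True
        mid = int(np.searchsorted(prefix_sum, m + 1)) - 1
        # since i = (l * valid_num + m) * trailing + t, shifting the middle
        # coordinate from m to mid lands on (l * middle + mid) * trailing + t
        return i + (l * (middle - valid_num) + (mid - m)) * trailing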
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index cbdf5b33654b..61aec6832d7e 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -1077,34 +1077,70 @@ def gt_grad_batch_dot_numpy(lhs, rhs, ograd, transpose_a, transpose_b, lhs_req,
 @use_np
 def test_npi_boolean_assign():
     class TestBooleanAssignScalar(HybridBlock):
-        def __init__(self, val):
+        def __init__(self, val, start_axis):
             super(TestBooleanAssignScalar, self).__init__()
             self._val = val
+            self._start_axis = start_axis
 
         def hybrid_forward(self, F, a, mask):
-            return F.np._internal.boolean_mask_assign_scalar(a, mask, self._val, out=a)
+            return F.np._internal.boolean_mask_assign_scalar(a, mask, self._val, start_axis=self._start_axis, out=a)
 
     class TestBooleanAssignTensor(HybridBlock):
-        def __init__(self):
+        def __init__(self, start_axis):
             super(TestBooleanAssignTensor, self).__init__()
+            self._start_axis = start_axis
 
         def hybrid_forward(self, F, a, mask, value):
-            return F.np._internal.boolean_mask_assign_tensor(a, mask, value, out=a)
+            return F.np._internal.boolean_mask_assign_tensor(a, mask, value, start_axis=self._start_axis, out=a)
+
+    configs = [
+        ((3, 4), (3, 4), 0),
+        ((3, 0), (3, 0), 0),
+        ((), (), 0),
+        ((2, 3, 4, 5), (2, 3), 0),
+        ((2, 3, 4, 5), (3, 4), 1),
+        ((2, 3, 4, 5), (4, 5), 2),
+    ]
 
-    shapes = [(3, 4), (3, 0), ()]
     for hybridize in [False]:
-        for shape in shapes:
-            test_data = np.random.uniform(size=shape)
-            mx_mask = np.around(np.random.uniform(size=shape))
+        for config in configs:
+            print(config)
+            dshape, mshape, start_axis = config
+            test_data = np.random.uniform(size=dshape)
+            mx_mask = np.around(np.random.uniform(size=mshape))
             valid_num = int(mx_mask.sum())
             np_mask = mx_mask.asnumpy().astype(_np.bool)
-            for val in [42., np.array(42.), np.array([42.]), np.random.uniform(size=(valid_num,))]:
-                test_block = TestBooleanAssignScalar(val) if isinstance(val, float) else TestBooleanAssignTensor()
+            vshape = []
+            for i in range(len(dshape)):
+                if i < start_axis:
+                    vshape.append(dshape[i])
+                elif i == start_axis:
+                    vshape.append(valid_num)
+                elif i >= start_axis + len(mshape):
+                    vshape.append(dshape[i])
+            vshape = tuple(vshape)
+            for val in [42.0, np.array(42.), np.array([42.]), np.random.uniform(size=vshape)]:
+                test_block = TestBooleanAssignScalar(val, start_axis) if isinstance(val, float) else TestBooleanAssignTensor(start_axis)
                 if hybridize:
                     test_block.hybridize()
                 np_data = test_data.asnumpy()
                 mx_data = test_data.copy()
-                np_data[np_mask] = val
+                trailing_axis = len(np_data.shape) - len(np_mask.shape) - start_axis
+                if start_axis == 0:
+                    if trailing_axis == 0:
+                        np_data[np_mask] = val
+                    elif trailing_axis == 1:
+                        np_data[np_mask, :] = val
+                    elif trailing_axis == 2:
+                        np_data[np_mask, :, :] = val
+                elif start_axis == 1:
+                    if trailing_axis == 0:
+                        np_data[:, np_mask] = val
+                    elif trailing_axis == 1:
+                        np_data[:, np_mask, :] = val
+                elif start_axis == 2:
+                    if trailing_axis == 0:
+                        np_data[:, :, np_mask] = val
                 mx_data = test_block(mx_data, mx_mask) if isinstance(val, float) else test_block(mx_data, mx_mask, val)
                 assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, use_broadcast=False)
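
For reference, the NumPy ground truth that the start_axis/trailing_axis
dispatch in the test builds can be reproduced standalone; e.g. the
((2, 3, 4, 5), (3, 4), 1) config corresponds to the following (the mask and
value chosen here are illustrative):

    import numpy as np

    data = np.arange(120.0).reshape(2, 3, 4, 5)
    mask = np.zeros((3, 4), dtype=bool)
    mask[0, 1] = mask[2, 3] = True               # valid_num == 2
    value = np.random.uniform(size=(2, 2, 5))    # (leading, valid_num, trailing)
    data[:, mask, :] = value                     # the start_axis=1, trailing_axis=1 case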