From 1c108001111b0d0267ec3c046812122da2123c7c Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Wed, 2 Oct 2019 22:23:27 +0000 Subject: [PATCH] boolean_mask_assign operator for future boolean indexing --- src/operator/numpy/np_boolean_mask_assign.cc | 270 +++++++++++++++++++ src/operator/numpy/np_boolean_mask_assign.cu | 213 +++++++++++++++ tests/python/unittest/test_numpy_op.py | 36 +++ 3 files changed, 519 insertions(+) create mode 100644 src/operator/numpy/np_boolean_mask_assign.cc create mode 100644 src/operator/numpy/np_boolean_mask_assign.cu diff --git a/src/operator/numpy/np_boolean_mask_assign.cc b/src/operator/numpy/np_boolean_mask_assign.cc new file mode 100644 index 000000000000..2a5ae116e291 --- /dev/null +++ b/src/operator/numpy/np_boolean_mask_assign.cc @@ -0,0 +1,270 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_boolean_assign.cc + * \brief CPU implementation of Boolean Mask Assign + */ + +#include "../contrib/boolean_mask-inl.h" + +namespace mxnet { +namespace op { + +template +struct BooleanAssignCPUKernel { + private: + static size_t bin_search(const size_t* idx, + const size_t idx_size, + const size_t i) { + size_t left = 0, right = idx_size, mid = (left + right) / 2; + while (left != right) { + if (idx[mid] == i + 1) { + if (idx[mid - 1] == i) { + mid -= 1; + break; + } else if (idx[mid - 1] == i + 1) { + right = mid; + mid = (left + right) / 2; + } + } else if (idx[mid] == i) { + if (idx[mid + 1] == i + 1) { + break; + } else { + left = mid; + mid = (left + right + 1) / 2; + } + } else if (idx[mid] < i + 1) { + left = mid; + mid = (left + right + 1) / 2; + } else if (idx[mid] > i + 1) { + right = mid; + mid = (left + right) / 2; + } + } + return mid; + } + + public: + template + static void Map(int i, + DType* data, + const size_t* idx, + const size_t idx_size, + const size_t leading, + const size_t middle, + const size_t trailing, + const DType val) { + // binary search for the turning point + size_t mid = bin_search(idx, idx_size, i); + // final answer is in mid + for (size_t l = 0; l < leading; ++l) { + for (size_t t = 0; t < trailing; ++t) { + data[(l * middle + mid) * trailing + t] = val; + } + } + } + + template + static void Map(int i, + DType* data, + const size_t* idx, + const size_t idx_size, + const size_t leading, + const size_t middle, + const size_t trailing, + DType* tensor) { + // binary search for the turning point + size_t mid = bin_search(idx, idx_size, i); + // final answer is in mid + for (size_t l = 0; l < leading; ++l) { + for (size_t t = 0; t < trailing; ++t) { + data[(l * middle + mid) * trailing + t] = (scalar) ? tensor[0] : tensor[i]; + } + } + } +}; + +bool BooleanAssignShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector *in_attrs, + mxnet::ShapeVector *out_attrs) { + CHECK(in_attrs->size() == 2U || in_attrs->size() == 3U); + CHECK_EQ(out_attrs->size(), 1U); + const TShape& dshape = in_attrs->at(0); + + // mask should have the same shape as the input + SHAPE_ASSIGN_CHECK(*in_attrs, 1, dshape); + + // check if output shape is the same as the input data + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); + + // for tensor version, the tensor should have less than 1 dimension + if (in_attrs->size() == 3U) { + CHECK_LE(in_attrs->at(2).ndim(), 1U) + << "boolean array indexing assignment requires a 0 or 1-dimensional input, input has " + << in_attrs->at(2).ndim() <<" dimensions"; + } + + return shape_is_known(out_attrs->at(0)); +} + +bool BooleanAssignType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK(in_attrs->size() == 2U || in_attrs->size() == 3U); + CHECK_EQ(out_attrs->size(), 1U); + + // input and output should always have the same type + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0)); + + if (in_attrs->size() == 3U) { + // if tensor version, the tensor should also have the same type as input + TYPE_ASSIGN_CHECK(*in_attrs, 2, in_attrs->at(0)); + TYPE_ASSIGN_CHECK(*in_attrs, 0, in_attrs->at(2)); + CHECK_NE(in_attrs->at(2), -1); + } + + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1 && in_attrs->at(1) != -1; +} + +// calculate the number of valid (masked) values, also completing the prefix_sum vector +template +size_t GetValidNumCPU(const DType* idx, size_t* prefix_sum, const size_t idx_size) { + prefix_sum[0] = 0; + for (size_t i = 0; i < idx_size; i++) { + prefix_sum[i + 1] = prefix_sum[i] + ((idx[i]) ? 1 : 0); + } + return prefix_sum[idx_size]; +} + +void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK(inputs.size() == 2U || inputs.size() == 3U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteInplace) + << "Only WriteInplace is supported for npi_boolean_assign"; + + Stream* s = ctx.get_stream(); + + const TBlob& data = inputs[0]; + const TBlob& mask = inputs[1]; + // Get valid_num + size_t valid_num = 0; + size_t mask_size = mask.shape_.Size(); + std::vector prefix_sum(mask_size + 1, 0); + MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, { + valid_num = GetValidNumCPU(mask.dptr(), prefix_sum.data(), mask_size); + }); + // If there's no True in mask, return directly + if (valid_num == 0) return; + + if (inputs.size() == 3U) { + if (inputs[2].shape_.Size() != 1) { + // tensor case, check tensor size with the valid_num + CHECK_EQ(static_cast(valid_num), inputs[2].shape_.Size()) + << "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size() + << " input values to the " << valid_num << " output values where the mask is true" + << std::endl; + } + } + + size_t leading = 1U; + size_t middle = mask_size; + size_t trailing = 1U; + + if (inputs.size() == 3U) { + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + if (inputs[2].shape_.Size() == 1) { + Kernel, cpu>::Launch( + s, valid_num, data.dptr(), prefix_sum.data(), prefix_sum.size(), + leading, middle, trailing, inputs[2].dptr()); + } else { + Kernel, cpu>::Launch( + s, valid_num, data.dptr(), prefix_sum.data(), prefix_sum.size(), + leading, middle, trailing, inputs[2].dptr()); + } + }); + } else { + CHECK(attrs.dict.find("value") != attrs.dict.end()) + << "value needs be provided"; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Kernel, cpu>::Launch( + s, valid_num, data.dptr(), prefix_sum.data(), prefix_sum.size(), + leading, middle, trailing, static_cast(std::stod(attrs.dict.at("value")))); + }); + } +} + +NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar) +.describe(R"code(Scalar version of boolean assign)code" ADD_FILELINE) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "mask"}; +}) +.set_attr("FInferShape", BooleanAssignShape) +.set_attr("FInferType", BooleanAssignType) +.set_attr("FCompute", NumpyBooleanAssignForwardCPU) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "input") +.add_argument("mask", "NDArray-or-Symbol", "mask") +.add_argument("value", "float", "value to be assigned to masked positions"); + +NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor) +.describe(R"code(Tensor version of boolean assign)code" ADD_FILELINE) +.set_num_inputs(3) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data", "mask", "value"}; +}) +.set_attr("FInferShape", BooleanAssignShape) +.set_attr("FInferType", BooleanAssignType) +.set_attr("FCompute", NumpyBooleanAssignForwardCPU) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs) { + return std::vector >{{0, 0}}; + }) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "input") +.add_argument("mask", "NDArray-or-Symbol", "mask") +.add_argument("value", "NDArray-or-Symbol", "assignment"); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_boolean_mask_assign.cu b/src/operator/numpy/np_boolean_mask_assign.cu new file mode 100644 index 000000000000..935fd30e195f --- /dev/null +++ b/src/operator/numpy/np_boolean_mask_assign.cu @@ -0,0 +1,213 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_boolean_mask_assign.cu + * \brief GPU implementation of Boolean Mask Assign + */ + +#include "../contrib/boolean_mask-inl.h" +#include + +namespace mxnet { +namespace op { + +struct BooleanAssignGPUKernel { + private: + static size_t __device__ bin_search(const size_t* idx, + const size_t idx_size, + const size_t i) { + size_t left = 0, right = idx_size, mid = (left + right) / 2; + while (left != right) { + if (idx[mid] == i + 1) { + if (idx[mid - 1] == i) { + mid -= 1; + break; + } else if (idx[mid - 1] == i + 1) { + right = mid; + mid = (left + right) / 2; + } + } else if (idx[mid] == i) { + if (idx[mid + 1] == i + 1) { + break; + } else { + left = mid; + mid = (left + right + 1) / 2; + } + } else if (idx[mid] < i + 1) { + left = mid; + mid = (left + right + 1) / 2; + } else if (idx[mid] > i + 1) { + right = mid; + mid = (left + right) / 2; + } + } + return mid; + } + + public: + template + static void __device__ Map(int i, + DType* data, + const size_t* idx, + const size_t idx_size, + const size_t leading, + const size_t middle, + const size_t trailing, + const DType val) { + // binary search for the turning point + size_t m = i / trailing % middle; + size_t mid = bin_search(idx, idx_size, m); + // final answer is in mid + data[i + (mid - m) * trailing] = val; + } + + template + static void __device__ Map(int i, + DType* data, + const size_t* idx, + const size_t idx_size, + const size_t leading, + const size_t middle, + const size_t trailing, + DType* tensor) { + // binary search for the turning point + size_t m = i / trailing % middle; + size_t mid = bin_search(idx, idx_size, m); + // final answer is in mid + data[i + (mid - m) * trailing] = tensor[m]; + } +}; + +struct NonZeroWithCast { + template + static void __device__ Map(int i, OType* out, const IType* in) { + out[i] = (in[i]) ? OType(1) : OType(0); + } +}; + +// completing the prefix_sum vector and return the pointer to it +template +size_t* GetValidNumGPU(const OpContext &ctx, const DType *idx, const size_t idx_size) { + using namespace mshadow; + using namespace mxnet_op; + size_t* prefix_sum = nullptr; + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + Stream* s = ctx.get_stream(); + + // Calculate total temporary memory size + cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + idx_size + 1, + Stream::GetStream(s)); + size_t buffer_size = (idx_size + 1) * sizeof(size_t); + temp_storage_bytes += buffer_size; + // Allocate memory on GPU and allocate pointer + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), s); + prefix_sum = reinterpret_cast(workspace.dptr_); + d_temp_storage = workspace.dptr_ + buffer_size; + + // Robustly set the bool values in mask + // TODO(haojin2): Get a more efficient way to preset the buffer + Kernel::Launch(s, idx_size + 1, prefix_sum); + if (!std::is_same::value) { + Kernel::Launch(s, idx_size, prefix_sum, idx); + } + + // Calculate prefix sum + cub::DeviceScan::ExclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + idx_size + 1, + Stream::GetStream(s)); + + return prefix_sum; +} + +void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + using namespace mxnet_op; + CHECK(inputs.size() == 2U || inputs.size() == 3U); + CHECK_EQ(outputs.size(), 1U); + CHECK_EQ(req.size(), 1U); + CHECK_EQ(req[0], kWriteInplace) + << "Only WriteInplace is supported for npi_boolean_assign"; + + Stream* s = ctx.get_stream(); + + const TBlob& data = inputs[0]; + const TBlob& mask = inputs[1]; + // Get valid_num + size_t mask_size = mask.shape_.Size(); + size_t valid_num = 0; + size_t* prefix_sum = nullptr; + MSHADOW_TYPE_SWITCH(mask.type_flag_, MType, { + prefix_sum = GetValidNumGPU(ctx, mask.dptr(), mask_size); + }); + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[mask_size - 1], sizeof(int32_t), + cudaMemcpyDeviceToHost)); + // If there's no True in mask, return directly + if (valid_num == 0) return; + + if (inputs.size() == 3U) { + // tensor case, check tensor size with the valid_num + CHECK_EQ(static_cast(valid_num), inputs[2].shape_.Size()) + << "boolean array indexing assignment cannot assign " << inputs[2].shape_.Size() + << " input values to the " << valid_num << "output values where the mask is true" + << std::endl; + } + + size_t leading = 1U; + size_t middle = mask_size; + size_t trailing = 1U; + + if (inputs.size() == 3U) { + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Kernel::Launch( + s, leading * valid_num * trailing, data.dptr(), prefix_sum, mask_size + 1, + leading, middle, trailing, inputs[2].dptr()); + }); + } else { + CHECK(attrs.dict.find("value") != attrs.dict.end()) + << "value is not provided"; + MSHADOW_TYPE_SWITCH(data.type_flag_, DType, { + Kernel::Launch( + s, leading * valid_num * trailing, data.dptr(), prefix_sum, mask_size + 1, + leading, middle, trailing, static_cast(std::stod(attrs.dict.at("value")))); + }); + } +} + +NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar) +.set_attr("FCompute", NumpyBooleanAssignForwardGPU); + +NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor) +.set_attr("FCompute", NumpyBooleanAssignForwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index eaf3032d526d..224125eb64f3 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -647,6 +647,42 @@ def hybrid_forward(self, F, a): assert same(a.grad.asnumpy(), expected_grad) +@with_seed() +@use_np +def test_npi_boolean_assign(): + class TestBooleanAssignScalar(HybridBlock): + def __init__(self, val): + super(TestBooleanAssignScalar, self).__init__() + self._val = val + + def hybrid_forward(self, F, a, mask): + return F.np._internal.boolean_mask_assign_scalar(a, mask, self._val, out=a) + + class TestBooleanAssignTensor(HybridBlock): + def __init__(self): + super(TestBooleanAssignTensor, self).__init__() + + def hybrid_forward(self, F, a, mask, value): + return F.np._internal.boolean_mask_assign_tensor(a, mask, value, out=a) + + shapes = [(3, 4), (3, 0), ()] + for hybridize in [False]: + for shape in shapes: + test_data = np.random.uniform(size=shape) + mx_mask = np.around(np.random.uniform(size=shape)) + valid_num = int(mx_mask.sum()) + np_mask = mx_mask.asnumpy().astype(_np.bool) + for val in [42., np.array(42.), np.array([42.]), np.random.uniform(size=(valid_num,))]: + test_block = TestBooleanAssignScalar(val) if isinstance(val, float) else TestBooleanAssignTensor() + if hybridize: + test_block.hybridize() + np_data = test_data.asnumpy() + mx_data = test_data.copy() + np_data[np_mask] = val + mx_data = test_block(mx_data, mx_mask) if isinstance(val, float) else test_block(mx_data, mx_mask, val) + assert_almost_equal(mx_data.asnumpy(), np_data, rtol=1e-3, atol=1e-5, use_broadcast=False) + + @with_seed() @use_np def test_np_reshape():