From 97cb5d64e3379396b9bd6d53da5e0123fe7bb8e7 Mon Sep 17 00:00:00 2001 From: Ying Date: Thu, 1 Aug 2019 21:36:17 +0800 Subject: [PATCH] numpy operator nonzero * add cpu test and handle 0-dim * add FGradient with MakeZeroGradNodes * handle 0-dim and 0-shape and add test on gpu * add doc * fix bug in review * do not use thrust::inclusive_scan on cpu * fix format error * edit test and remove gpu test The output is same as numpy.transpose(numpy.nonzero(x)) --- python/mxnet/_numpy_op_doc.py | 49 ++++++++++ src/operator/numpy/np_nonzero_op-inl.h | 65 +++++++++++++ src/operator/numpy/np_nonzero_op.cc | 129 ++++++++++++++++++++++++ src/operator/numpy/np_nonzero_op.cu | 130 +++++++++++++++++++++++++ tests/python/unittest/test_numpy_op.py | 38 ++++++++ 5 files changed, 411 insertions(+) create mode 100644 src/operator/numpy/np_nonzero_op-inl.h create mode 100644 src/operator/numpy/np_nonzero_op.cc create mode 100644 src/operator/numpy/np_nonzero_op.cu diff --git a/python/mxnet/_numpy_op_doc.py b/python/mxnet/_numpy_op_doc.py index 5543ebc8e8c9..6052382872fe 100644 --- a/python/mxnet/_numpy_op_doc.py +++ b/python/mxnet/_numpy_op_doc.py @@ -52,3 +52,52 @@ def _np_zeros_like(a): Array of zeros with the same shape and type as `a`. """ pass + + +def _npx_nonzero(a): + """ + nonzero(a) + + Return the indices of the elements that are non-zero. + + Returns a ndarray with ndim is 2. Each row contains the indices + of the non-zero elements. The values in `a` are always tested and returned in + row-major, C-style order. + + The result of this is always a 2-D array, with a row for + each non-zero element. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + array : ndarray + Indices of elements that are non-zero. + + Notes + ----- + This function differs from the original numpy.prod in the following aspects: + - Do not support python numeric. + - The return value is same as numpy.transpose(numpy.nonzero(a)). + + Examples + -------- + >>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]]) + >>> x + array([[3, 0, 0], + [0, 4, 0], + [5, 6, 0]]) + >>> npx.nonzero(x) + array([[0, 0], + [1, 1], + [2, 0], + [2, 1]], dtype=int64) + + >>> np.transpose(npx.nonzero(x)) + array([[0, 1, 2, 2], + [0, 1, 0, 1]], dtype=int64) + """ + pass diff --git a/src/operator/numpy/np_nonzero_op-inl.h b/src/operator/numpy/np_nonzero_op-inl.h new file mode 100644 index 000000000000..88929c43e5b5 --- /dev/null +++ b/src/operator/numpy/np_nonzero_op-inl.h @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file np_nonzero_op-inl.h +*/ + +#ifndef MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include "../operator_common.h" +#include "../mxnet_op.h" +#include "../tensor/init_op.h" +#include "../mshadow_op.h" +#include "../elemwise_op_common.h" + +namespace mxnet { +namespace op { + +struct NonzeroForwardKernel { + template + MSHADOW_XINLINE static void Map(int i, + int64_t* out, + const int32_t* idx, + const mshadow::Shape shape) { + int32_t prev = (i == 0) ? 0 : idx[i - 1]; + int32_t curr = idx[i]; + if (prev != curr) { + mshadow::Shape coord = mxnet_op::unravel(i, shape); + for (int j = 0; j < ndim; j++) { + out[prev * ndim + j] = coord[j]; + } + } + } +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_ diff --git a/src/operator/numpy/np_nonzero_op.cc b/src/operator/numpy/np_nonzero_op.cc new file mode 100644 index 000000000000..00f9081ba984 --- /dev/null +++ b/src/operator/numpy/np_nonzero_op.cc @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file np_nonzero_op.cc +*/ +#include "np_nonzero_op-inl.h" + +namespace mxnet { +namespace op { + +bool NonzeroType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + // Output must be int64. + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64); + return out_attrs->at(0) != -1; +} + +#define MAXDIM 5 + +bool NonzeroStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + CHECK_EQ(in_attrs->size(), 1); + CHECK_EQ(out_attrs->size(), 1); + for (int &attr : *in_attrs) { + CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; + } + for (int &attr : *out_attrs) { + attr = kDefaultStorage; + } + *dispatch_mode = DispatchMode::kFComputeEx; + return true; +} + +void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const NDArray &in = inputs[0]; + const NDArray &out = outputs[0]; + CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM; + // 0-dim + if (0 == in.shape().ndim()) { + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + DType* in_dptr = in.data().dptr(); + if (*in_dptr) { + mxnet::TShape s(2, 1); + const_cast(out).Init(s); + *(out.data().dptr()) = 0; + } else { + mxnet::TShape s(2, 1); + s[0] = 0; + const_cast(out).Init(s); + } + }); + return; + } + size_t in_size = in.shape().Size(); + // 0-shape + if (0 == in_size) { + mxnet::TShape s(2, in.shape().ndim()); + s[0] = 0; + const_cast(out).Init(s); + return; + } + std::vector prefix_sum(in_size, 0); + size_t valid_num = 0; + // Calculate prefix sum + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + DType* in_dptr = in.data().dptr(); + for (size_t i = 0; i < in_size; i++) { + prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1]; + prefix_sum[i] += (in_dptr[i]) ? 1 : 0; + } + }); + valid_num = prefix_sum[in_size - 1]; + // set the output shape forcefully + mxnet::TShape s(2, in.shape().ndim()); + s[0] = valid_num; + const_cast(out).Init(s); + // get the shape from the input + MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, { + mshadow::Shape shape = in.shape().get(); + mshadow::Stream *stream = ctx.get_stream(); + mxnet_op::Kernel::Launch( + stream, in_size, out.data().dptr(), prefix_sum.data(), shape); + }) +} + +NNVM_REGISTER_OP(_npx_nonzero) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"x"}; + }) +.set_attr("FInferType", NonzeroType) +.set_attr("FComputeEx", NonzeroForwardCPU) +.set_attr("FInferStorageType", NonzeroStorageType) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("x", "NDArray-or-Symbol", "The input array."); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_nonzero_op.cu b/src/operator/numpy/np_nonzero_op.cu new file mode 100644 index 000000000000..33925ea2e156 --- /dev/null +++ b/src/operator/numpy/np_nonzero_op.cu @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2018 by Contributors + * \file np_nonzero_op.cu +*/ + +#include "np_nonzero_op-inl.h" +#include + +namespace mxnet { +namespace op { + +struct PrefixSumInit { + template + MSHADOW_XINLINE static void Map(int i, + int32_t* out, + DType* in) { + if (in[i]) { + out[i] = 1; + } else { + out[i] = 0; + } + } +}; + +#define MAXDIM 5 + +void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mshadow; + CHECK_EQ(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + const NDArray &in = inputs[0]; + const NDArray &out = outputs[0]; + CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM; + size_t in_size = in.shape().Size(); + // 0-shape + if (0 == in_size) { + mxnet::TShape s(2, in.shape().ndim()); + s[0] = 0; + const_cast(out).Init(s); + return; + } + int32_t valid_num = 0; + Stream* stream = ctx.get_stream(); + int32_t* prefix_sum = nullptr; + void* d_temp_storage = nullptr; + size_t temp_storage_bytes = 0; + // Calculate total temporary memory size + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + in_size, + Stream::GetStream(stream)); + size_t buffer_size = in_size * sizeof(int32_t); + temp_storage_bytes += buffer_size; + // Allocate memory on GPU and allocate pointer + Tensor workspace = + ctx.requested[0].get_space_typed(Shape1(temp_storage_bytes), stream); + prefix_sum = reinterpret_cast(workspace.dptr_); + d_temp_storage = workspace.dptr_ + buffer_size; + MSHADOW_TYPE_SWITCH(in.dtype(), DType, { + mxnet_op::Kernel::Launch( + stream, in_size, prefix_sum, in.data().dptr()); + }); + // Calculate prefix sum + cub::DeviceScan::InclusiveSum(d_temp_storage, + temp_storage_bytes, + prefix_sum, + prefix_sum, + in_size, + Stream::GetStream(stream)); + CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t), + cudaMemcpyDeviceToHost)); + // 0-dim + if (0 == in.shape().ndim()) { + mxnet::TShape s(2, 1); + if (valid_num) { + const_cast(out).Init(s); + int64_t temp = 0; + CUDA_CALL(cudaMemcpy(out.data().dptr(), &temp, sizeof(int64_t), + cudaMemcpyHostToDevice)); + } else { + s[0] = 0; + const_cast(out).Init(s); + } + return; + } + // Set the output shape forcefully + mxnet::TShape s(2, in.shape().ndim()); + s[0] = valid_num; + const_cast(out).Init(s); + // get the shape from the input + MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, { + mshadow::Shape shape = in.shape().get(); + mxnet_op::Kernel::Launch( + stream, in_size, out.data().dptr(), prefix_sum, shape); + }) +} + +NNVM_REGISTER_OP(_npx_nonzero) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) +.set_attr("FComputeEx", NonzeroForwardGPU); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index b179f67e6128..7de5caa28526 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -92,6 +92,44 @@ def is_int(dtype): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_nonzero(): + class TestNonzero(HybridBlock): + def __init__(self): + super(TestNonzero, self).__init__() + + def hybrid_forward(self, F, x): + return F.npx.nonzero(x) + + types = ['int32', 'int64', 'float64', 'float32', 'float16'] + for hybridize in [True, False]: + for shape in [(), + (1,), + (1, 1), + (1, 2, 3), + (1, 0), + (2, 0, 3) + ]: + for oneType in types: + rtol=1e-3 + atol=1e-5 + test_nonzero = TestNonzero() + if hybridize: + test_nonzero.hybridize() + x = rand_ndarray(shape, dtype=oneType).as_np_ndarray() + np_out = _np.nonzero(x.asnumpy()) + np_out = _np.transpose(np_out) + mx_out = test_nonzero(x) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol) + + # Test imperative once again + mx_out = npx.nonzero(x) + np_out = _np.nonzero(x.asnumpy()) + np_out = _np.transpose(np_out) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol, atol) + if __name__ == '__main__': import nose nose.runmodule()