Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[numpy] nonzero #15838

Merged
merged 1 commit into from
Sep 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,55 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):
>>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows
array([[ 1, 3, 6],
[ 4, 9, 15]])
"""
pass


def _npx_nonzero(a):
haojin2 marked this conversation as resolved.
Show resolved Hide resolved
"""
nonzero(a)

Return the indices of the elements that are non-zero.

Returns a ndarray with ndim is 2. Each row contains the indices
of the non-zero elements. The values in `a` are always tested and returned in
row-major, C-style order.

The result of this is always a 2-D array, with a row for
each non-zero element.

Parameters
----------
a : array_like
Input array.

Returns
-------
array : ndarray
Indices of elements that are non-zero.

Notes
-----
This function differs from the original numpy.prod in the following aspects:
- Do not support python numeric.
- The return value is same as numpy.transpose(numpy.nonzero(a)).

Examples
--------
>>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
>>> x
array([[3, 0, 0],
[0, 4, 0],
[5, 6, 0]])
>>> npx.nonzero(x)
array([[0, 0],
[1, 1],
[2, 0],
[2, 1]], dtype=int64)

>>> np.transpose(npx.nonzero(x))
array([[0, 1, 2, 2],
[0, 1, 0, 1]], dtype=int64)
"""
pass

Expand Down
65 changes: 65 additions & 0 deletions src/operator/numpy/np_nonzero_op-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op-inl.h
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/ndarray.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/init_op.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

struct NonzeroForwardKernel {
template<int ndim>
MSHADOW_XINLINE static void Map(int i,
int64_t* out,
const int32_t* idx,
const mshadow::Shape<ndim> shape) {
int32_t prev = (i == 0) ? 0 : idx[i - 1];
int32_t curr = idx[i];
if (prev != curr) {
mshadow::Shape<ndim> coord = mxnet_op::unravel<ndim>(i, shape);
for (int j = 0; j < ndim; j++) {
out[prev * ndim + j] = coord[j];
}
}
}
};

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
129 changes: 129 additions & 0 deletions src/operator/numpy/np_nonzero_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op.cc
*/
#include "np_nonzero_op-inl.h"

namespace mxnet {
namespace op {

bool NonzeroType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
// Output must be int64.
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
}

#define MAXDIM 5

bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
for (int &attr : *in_attrs) {
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
}
for (int &attr : *out_attrs) {
attr = kDefaultStorage;
}
*dispatch_mode = DispatchMode::kFComputeEx;
return true;
}

void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const NDArray &in = inputs[0];
const NDArray &out = outputs[0];
CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM;
// 0-dim
if (0 == in.shape().ndim()) {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
DType* in_dptr = in.data().dptr<DType>();
if (*in_dptr) {
mxnet::TShape s(2, 1);
const_cast<NDArray &>(out).Init(s);
*(out.data().dptr<int64_t>()) = 0;
} else {
mxnet::TShape s(2, 1);
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
}
});
return;
}
size_t in_size = in.shape().Size();
// 0-shape
if (0 == in_size) {
mxnet::TShape s(2, in.shape().ndim());
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
return;
}
std::vector<int32_t> prefix_sum(in_size, 0);
size_t valid_num = 0;
// Calculate prefix sum
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
DType* in_dptr = in.data().dptr<DType>();
for (size_t i = 0; i < in_size; i++) {
prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
}
});
valid_num = prefix_sum[in_size - 1];
// set the output shape forcefully
mxnet::TShape s(2, in.shape().ndim());
s[0] = valid_num;
const_cast<NDArray &>(out).Init(s);
// get the shape from the input
MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
mshadow::Shape<ndim> shape = in.shape().get<ndim>();
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
})
}

NNVM_REGISTER_OP(_npx_nonzero)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.set_attr<nnvm::FInferType>("FInferType", NonzeroType)
.set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
.set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("x", "NDArray-or-Symbol", "The input array.");

} // namespace op
} // namespace mxnet
130 changes: 130 additions & 0 deletions src/operator/numpy/np_nonzero_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op.cu
*/

#include "np_nonzero_op-inl.h"
#include <cub/cub.cuh>

namespace mxnet {
namespace op {

struct PrefixSumInit {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
int32_t* out,
DType* in) {
if (in[i]) {
out[i] = 1;
} else {
out[i] = 0;
}
}
};

#define MAXDIM 5

void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const NDArray &in = inputs[0];
const NDArray &out = outputs[0];
CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot larger than " << MAXDIM;
size_t in_size = in.shape().Size();
// 0-shape
if (0 == in_size) {
mxnet::TShape s(2, in.shape().ndim());
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
return;
}
int32_t valid_num = 0;
Stream<gpu>* stream = ctx.get_stream<gpu>();
int32_t* prefix_sum = nullptr;
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
// Calculate total temporary memory size
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
in_size,
Stream<gpu>::GetStream(stream));
size_t buffer_size = in_size * sizeof(int32_t);
temp_storage_bytes += buffer_size;
// Allocate memory on GPU and allocate pointer
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), stream);
prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + buffer_size;
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
mxnet_op::Kernel<PrefixSumInit, gpu>::Launch(
stream, in_size, prefix_sum, in.data().dptr<DType>());
});
// Calculate prefix sum
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
in_size,
Stream<gpu>::GetStream(stream));
CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t),
cudaMemcpyDeviceToHost));
// 0-dim
if (0 == in.shape().ndim()) {
mxnet::TShape s(2, 1);
if (valid_num) {
const_cast<NDArray &>(out).Init(s);
int64_t temp = 0;
CUDA_CALL(cudaMemcpy(out.data().dptr<int64_t>(), &temp, sizeof(int64_t),
cudaMemcpyHostToDevice));
} else {
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
}
return;
}
// Set the output shape forcefully
mxnet::TShape s(2, in.shape().ndim());
s[0] = valid_num;
const_cast<NDArray &>(out).Init(s);
// get the shape from the input
MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
mshadow::Shape<ndim> shape = in.shape().get<ndim>();
mxnet_op::Kernel<NonzeroForwardKernel, gpu>::Launch(
stream, in_size, out.data().dptr<int64_t>(), prefix_sum, shape);
})
}

NNVM_REGISTER_OP(_npx_nonzero)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FComputeEx>("FComputeEx<gpu>", NonzeroForwardGPU);

} // namespace op
} // namespace mxnet
Loading