numpy operator nonzero
* add cpu test and handle 0-dim

* add FGradient with MakeZeroGradNodes

* handle 0-dim and 0-shape and add test on gpu

* add doc

* fix bug found in review

* do not use thrust::inclusive_scan on cpu

* fix format error

* edit test and remove gpu test

The output is the same as numpy.transpose(numpy.nonzero(x))

* fix errors from review

* edit test
Ying authored and tingying2020 committed Sep 19, 2019
1 parent a37a76c commit eb41507
Showing 5 changed files with 404 additions and 0 deletions.
48 changes: 48 additions & 0 deletions python/mxnet/_numpy_op_doc.py
@@ -109,7 +109,55 @@ def _np_cumsum(a, axis=None, dtype=None, out=None):
>>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows
array([[ 1, 3, 6],
[ 4, 9, 15]])
"""
pass


def _npx_nonzero(a):
"""
nonzero(a)
Return the indices of the elements that are non-zero.
Returns an ndarray with ndim equal to 2: one row per non-zero element,
containing the indices of that element. The values in `a` are always
tested and returned in row-major, C-style order.
Parameters
----------
a : array_like
Input array.
Returns
-------
array : ndarray
Indices of elements that are non-zero.
Notes
-----
This function differs from the original numpy.nonzero in the following aspects:
- Does not support Python numeric inputs.
- The return value is the same as numpy.transpose(numpy.nonzero(a)).
Examples
--------
>>> x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
>>> x
array([[3, 0, 0],
[0, 4, 0],
[5, 6, 0]])
>>> npx.nonzero(x)
array([[0, 0],
[1, 1],
[2, 0],
[2, 1]], dtype=int64)
>>> np.transpose(npx.nonzero(x))
array([[0, 1, 2, 2],
[0, 1, 0, 1]], dtype=int64)
"""
pass
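
A pure-NumPy reference for the documented behaviour can be sketched as follows (a hedged illustration; the helper name is hypothetical and not part of this commit):

def _nonzero_reference(a):
    # Stack the per-axis index arrays returned by numpy.nonzero into an
    # (N, ndim) int64 array, matching the documented output of npx.nonzero.
    import numpy as onp
    a = onp.asarray(a)
    return onp.transpose(onp.nonzero(a)).astype('int64')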

65 changes: 65 additions & 0 deletions src/operator/numpy/np_nonzero_op-inl.h
@@ -0,0 +1,65 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op-inl.h
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/ndarray.h>
#include <map>
#include <vector>
#include <string>
#include <utility>
#include <algorithm>
#include "../operator_common.h"
#include "../mxnet_op.h"
#include "../tensor/init_op.h"
#include "../mshadow_op.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

struct NonzeroForwardKernel {
template<int ndim>
MSHADOW_XINLINE static void Map(int i,
int64_t* out,
const int32_t* idx,
const mshadow::Shape<ndim> shape) {
// idx is the inclusive prefix sum of the non-zero mask, so it increases by one
// exactly at the positions that hold a non-zero element.
int32_t prev = (i == 0) ? 0 : idx[i - 1];
int32_t curr = idx[i];
if (prev != curr) {
// Element i is non-zero; write its unraveled coordinate into output row `prev`.
mshadow::Shape<ndim> coord = mxnet_op::unravel<ndim>(i, shape);
for (int j = 0; j < ndim; j++) {
out[prev * ndim + j] = coord[j];
}
}
}
};

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_NONZERO_OP_INL_H_
129 changes: 129 additions & 0 deletions src/operator/numpy/np_nonzero_op.cc
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op.cc
*/
#include "np_nonzero_op-inl.h"

namespace mxnet {
namespace op {

bool NonzeroType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
// Output must be int64.
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
return out_attrs->at(0) != -1;
}

#define MAXDIM 5

bool NonzeroStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 1);
for (int &attr : *in_attrs) {
CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
}
for (int &attr : *out_attrs) {
attr = kDefaultStorage;
}
*dispatch_mode = DispatchMode::kFComputeEx;
return true;
}

void NonzeroForwardCPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const NDArray &in = inputs[0];
const NDArray &out = outputs[0];
CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
// 0-dim
if (0 == in.shape().ndim()) {
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
DType* in_dptr = in.data().dptr<DType>();
if (*in_dptr) {
mxnet::TShape s(2, 1);
const_cast<NDArray &>(out).Init(s);
*(out.data().dptr<int64_t>()) = 0;
} else {
mxnet::TShape s(2, 1);
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
}
});
return;
}
size_t in_size = in.shape().Size();
// 0-shape
if (0 == in_size) {
mxnet::TShape s(2, in.shape().ndim());
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
return;
}
std::vector<int32_t> prefix_sum(in_size, 0);
size_t valid_num = 0;
// Calculate inclusive prefix sum over the non-zero mask
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
DType* in_dptr = in.data().dptr<DType>();
for (size_t i = 0; i < in_size; i++) {
prefix_sum[i] = (i == 0) ? 0 : prefix_sum[i - 1];
prefix_sum[i] += (in_dptr[i]) ? 1 : 0;
}
});
valid_num = prefix_sum[in_size - 1];
// set the output shape forcefully
mxnet::TShape s(2, in.shape().ndim());
s[0] = valid_num;
const_cast<NDArray &>(out).Init(s);
// get the shape from the input
MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
mshadow::Shape<ndim> shape = in.shape().get<ndim>();
mshadow::Stream<cpu> *stream = ctx.get_stream<cpu>();
mxnet_op::Kernel<NonzeroForwardKernel, cpu>::Launch(
stream, in_size, out.data().dptr<int64_t>(), prefix_sum.data(), shape);
})
}
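
The CPU path above can be summarized in NumPy-style Python (a minimal sketch for host-side arrays with ndim >= 1; the 0-dim and 0-size branches mirror the early returns in NonzeroForwardCPU, and the helper name is hypothetical):

def _nonzero_forward_sketch(a):
    import numpy as np
    a = np.asarray(a)
    flat = a.reshape(-1)
    if flat.size == 0:
        return np.zeros((0, a.ndim), dtype='int64')
    # Inclusive prefix sum over the non-zero mask, as in prefix_sum[] above.
    scan = np.cumsum(flat != 0)
    out = np.zeros((int(scan[-1]), a.ndim), dtype='int64')
    for i in range(flat.size):
        prev = int(scan[i - 1]) if i > 0 else 0
        if int(scan[i]) != prev:
            # Element i is non-zero; row `prev` receives its unraveled coordinate.
            out[prev] = np.unravel_index(i, a.shape)
    return out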

NNVM_REGISTER_OP(_npx_nonzero)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"x"};
})
.set_attr<nnvm::FInferType>("FInferType", NonzeroType)
.set_attr<FComputeEx>("FComputeEx<cpu>", NonzeroForwardCPU)
.set_attr<FInferStorageType>("FInferStorageType", NonzeroStorageType)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_argument("x", "NDArray-or-Symbol", "The input array.");

} // namespace op
} // namespace mxnet
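
A minimal frontend check of the registered operator, mirroring the docstring example (assuming the standard mxnet.numpy / mxnet.numpy_extension frontends):

from mxnet import np, npx
npx.set_np()
x = np.array([[3, 0, 0], [0, 4, 0], [5, 6, 0]])
print(npx.nonzero(x))  # (4, 2) int64 array: [[0 0] [1 1] [2 0] [2 1]]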
130 changes: 130 additions & 0 deletions src/operator/numpy/np_nonzero_op.cu
@@ -0,0 +1,130 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* Copyright (c) 2018 by Contributors
* \file np_nonzero_op.cu
*/

#include "np_nonzero_op-inl.h"
#include <cub/cub.cuh>

namespace mxnet {
namespace op {

struct PrefixSumInit {
template<typename DType>
MSHADOW_XINLINE static void Map(int i,
int32_t* out,
DType* in) {
if (in[i]) {
out[i] = 1;
} else {
out[i] = 0;
}
}
};

#define MAXDIM 5

void NonzeroForwardGPU(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using namespace mshadow;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
const NDArray &in = inputs[0];
const NDArray &out = outputs[0];
CHECK_LE(in.shape().ndim(), MAXDIM) << "ndim of input cannot be larger than " << MAXDIM;
size_t in_size = in.shape().Size();
// 0-shape
if (0 == in_size) {
mxnet::TShape s(2, in.shape().ndim());
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
return;
}
int32_t valid_num = 0;
Stream<gpu>* stream = ctx.get_stream<gpu>();
int32_t* prefix_sum = nullptr;
void* d_temp_storage = nullptr;
size_t temp_storage_bytes = 0;
// Query the required cub temp storage size; this first call with a null
// d_temp_storage only reports the size and does not scan.
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
in_size,
Stream<gpu>::GetStream(stream));
size_t buffer_size = in_size * sizeof(int32_t);
temp_storage_bytes += buffer_size;
// Allocate one workspace on the GPU and carve it into the prefix-sum buffer
// followed by the cub temp storage.
Tensor<gpu, 1, char> workspace =
ctx.requested[0].get_space_typed<gpu, 1, char>(Shape1(temp_storage_bytes), stream);
prefix_sum = reinterpret_cast<int32_t*>(workspace.dptr_);
d_temp_storage = workspace.dptr_ + buffer_size;
MSHADOW_TYPE_SWITCH(in.dtype(), DType, {
mxnet_op::Kernel<PrefixSumInit, gpu>::Launch(
stream, in_size, prefix_sum, in.data().dptr<DType>());
});
// Calculate inclusive prefix sum over the non-zero mask
cub::DeviceScan::InclusiveSum(d_temp_storage,
temp_storage_bytes,
prefix_sum,
prefix_sum,
in_size,
Stream<gpu>::GetStream(stream));
CUDA_CALL(cudaMemcpy(&valid_num, &prefix_sum[in_size - 1], sizeof(int32_t),
cudaMemcpyDeviceToHost));
// 0-dim
if (0 == in.shape().ndim()) {
mxnet::TShape s(2, 1);
if (valid_num) {
const_cast<NDArray &>(out).Init(s);
int64_t temp = 0;
CUDA_CALL(cudaMemcpy(out.data().dptr<int64_t>(), &temp, sizeof(int64_t),
cudaMemcpyHostToDevice));
} else {
s[0] = 0;
const_cast<NDArray &>(out).Init(s);
}
return;
}
// Set the output shape forcefully
mxnet::TShape s(2, in.shape().ndim());
s[0] = valid_num;
const_cast<NDArray &>(out).Init(s);
// get the shape from the input
MXNET_NDIM_SWITCH(in.shape().ndim(), ndim, {
mshadow::Shape<ndim> shape = in.shape().get<ndim>();
mxnet_op::Kernel<NonzeroForwardKernel, gpu>::Launch(
stream, in_size, out.data().dptr<int64_t>(), prefix_sum, shape);
})
}

NNVM_REGISTER_OP(_npx_nonzero)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FComputeEx>("FComputeEx<gpu>", NonzeroForwardGPU);

} // namespace op
} // namespace mxnet