From f40484677c23368068c0ff75ae7910b5ca72461e Mon Sep 17 00:00:00 2001 From: Minghao Liu Date: Fri, 28 Feb 2020 04:27:15 +0000 Subject: [PATCH] ffi_bincount --- python/mxnet/ndarray/numpy/_op.py | 6 +-- src/api/operator/numpy/np_bincount_op.cc | 56 ++++++++++++++++++++++++ src/operator/numpy/np_bincount_op-inl.h | 8 ++++ 3 files changed, 65 insertions(+), 5 deletions(-) create mode 100644 src/api/operator/numpy/np_bincount_op.cc diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 242931862988..d465057dd726 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -7341,13 +7341,9 @@ def bincount(x, weights=None, minlength=0): >>> np.bincount(x, weights=w) array([ 0.3, 0.7, 1.1]) """ - if not isinstance(x, NDArray): - raise TypeError("Input data should be NDarray") if minlength < 0: raise ValueError("Minlength value should greater than 0") - if weights is None: - return _npi.bincount(x, minlength=minlength, has_weights=False) - return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True) + return _api_internal.bincount(x, weights, minlength) @set_module('mxnet.ndarray.numpy') diff --git a/src/api/operator/numpy/np_bincount_op.cc b/src/api/operator/numpy/np_bincount_op.cc new file mode 100644 index 000000000000..4baebdeb8223 --- /dev/null +++ b/src/api/operator/numpy/np_bincount_op.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_bincount_op.cc + * \brief Implementation of the API of functions in src/operator/numpy/np_bincount_op.cc + */ +#include "../utils.h" +#include "../../../operator/numpy/np_bincount_op-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.bincount") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_bincount"); + nnvm::NodeAttrs attrs; + op::NumpyBincountParam param; + + int num_outputs = 0; + if (args[1].type_code() == kNull) { + param.minlength = args[2].operator int64_t(); + param.has_weights = false; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + attrs.parsed = param; // remove std::move() ci error trivially-copyable type + attrs.op = op; + auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } else { + param.minlength = args[2].operator int64_t(); + param.has_weights = true; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + attrs.parsed = param; // remove std::move() ci error trivially-copyable type + attrs.op = op; + auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); + } +}); + +} // namespace mxnet diff --git a/src/operator/numpy/np_bincount_op-inl.h b/src/operator/numpy/np_bincount_op-inl.h index 254ea8fdec22..a2e758f36059 100644 --- a/src/operator/numpy/np_bincount_op-inl.h +++ b/src/operator/numpy/np_bincount_op-inl.h @@ -28,6 +28,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../operator_common.h" @@ -50,6 +51,13 @@ struct NumpyBincountParam : public dmlc::Parameter { .set_default(false) .describe("Determine whether Bincount has weights."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream minlength_s, has_weights_s; + minlength_s << minlength; + has_weights_s << has_weights; + (*dict)["minlength"] = minlength_s.str(); + (*dict)["has_weights"] = has_weights_s.str(); + } }; inline bool NumpyBincountType(const nnvm::NodeAttrs& attrs,