diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 538d5202942d..376d67714bb7 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -40,7 +40,7 @@ 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] @set_module('mxnet.ndarray.numpy') @@ -5864,3 +5864,60 @@ def where(condition, x=None, y=None): return nonzero(condition) else: return _npi.where(condition, x, y, out=None) + + +@set_module('mxnet.ndarray.numpy') +def bincount(x, weights=None, minlength=0): + """ + Count number of occurrences of each value in array of non-negative ints. + + Parameters + ---------- + x : ndarray + input array, 1 dimension, nonnegative ints. + weights: ndarray + input weigths same shape as x. (Optional) + minlength: int + A minimum number of bins for the output. (Optional) + + Returns + -------- + out : ndarray + the result of binning the input array. The length of out is equal to amax(x)+1. + + Raises + -------- + Value Error + If the input is not 1-dimensional, or contains elements with negative values, + or if minlength is negative + TypeError + If the type of the input is float or complex. + + Examples + -------- + >>> np.bincount(np.arange(5)) + array([1, 1, 1, 1, 1]) + >>> np.bincount(np.array([0, 1, 1, 3, 2, 1, 7])) + array([1, 3, 1, 1, 0, 0, 0, 1]) + + >>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23]) + >>> np.bincount(x).size == np.amax(x)+1 + True + + >>> np.bincount(np.arange(5, dtype=float)) + Traceback (most recent call last): + File "", line 1, in + TypeError: array cannot be safely cast to required type + + >>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6]) # weights + >>> x = np.array([0, 1, 1, 2, 2, 2]) + >>> np.bincount(x, weights=w) + array([ 0.3, 0.7, 1.1]) + """ + if not isinstance(x, NDArray): + raise TypeError("Input data should be NDarray") + if minlength < 0: + raise ValueError("Minlength value should greater than 0") + if weights is None: + return _npi.bincount(x, minlength=minlength, has_weights=False) + return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True) diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index aa0762bf0e3f..4c8308ef364e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -59,7 +59,7 @@ 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] # Return code for dispatching indexing function call _NDARRAY_UNSUPPORTED_INDEXING = -1 @@ -7840,3 +7840,54 @@ def where(condition, x=None, y=None): [ 0., 3., -1.]]) """ return _mx_nd_np.where(condition, x, y) + + +@set_module('mxnet.numpy') +def bincount(x, weights=None, minlength=0): + """ + Count number of occurrences of each value in array of non-negative ints. + + Parameters + ---------- + x : ndarray + input array, 1 dimension, nonnegative ints. + weights: ndarray + input weigths same shape as x. (Optional) + minlength: int + A minimum number of bins for the output. (Optional) + + Returns + -------- + out : ndarray + the result of binning the input array. The length of out is equal to amax(x)+1. + + Raises + -------- + Value Error + If the input is not 1-dimensional, or contains elements with negative values, + or if minlength is negative + TypeError + If the type of the input is float or complex. + + Examples + -------- + >>> np.bincount(np.arange(5)) + array([1, 1, 1, 1, 1]) + >>> np.bincount(np.array([0, 1, 1, 3, 2, 1, 7])) + array([1, 3, 1, 1, 0, 0, 0, 1]) + + >>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23]) + >>> np.bincount(x).size == np.amax(x)+1 + True + + >>> np.bincount(np.arange(5, dtype=float)) + Traceback (most recent call last): + File "", line 1, in + TypeError: array cannot be safely cast to required type + + >>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6]) # weights + >>> x = np.array([0, 1, 1, 2, 2, 2]) + >>> np.bincount(x, weights=w) + array([ 0.3, 0.7, 1.1]) + """ + return _mx_nd_np.bincount(x, weights=weights, minlength=minlength) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1a238ec2c7c7..9b2d20cca0c1 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -145,6 +145,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'resize', 'where', 'full_like', + 'bincount' ] diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 4b06bbec7cae..18cf5d15be92 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -48,7 +48,7 @@ 'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', - 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where'] + 'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount'] @set_module('mxnet.symbol.numpy') @@ -5420,4 +5420,38 @@ def load_json(json_str): return _Symbol(handle) +@set_module('mxnet.symbol.numpy') +def bincount(x, weights=None, minlength=0): + """ + Count number of occurrences of each value in array of non-negative ints. + + Parameters + ---------- + x : _Symbol + input data + weights: _Symbol + input weigths same shape as x. (Optional) + minlength: int + A minimum number of bins for the output. (Optional) + + Returns + -------- + out : _Symbol + the result of binning the input data. The length of out is equal to amax(x)+1. + + Raises: + -------- + Value Error + If the input is not 1-dimensional, or contains elements with negative values, + or if minlength is negative + TypeError + If the type of the input is float or complex. + """ + if minlength < 0: + raise ValueError("Minlength value should greater than 0") + if weights is None: + return _npi.bincount(x, minlength=minlength, has_weights=False) + return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True) + + _set_np_symbol_class(_Symbol) diff --git a/src/operator/numpy/np_bincount_op-inl.h b/src/operator/numpy/np_bincount_op-inl.h new file mode 100644 index 000000000000..254ea8fdec22 --- /dev/null +++ b/src/operator/numpy/np_bincount_op-inl.h @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bicount_op-inl.h + * \brief numpy compatible bincount operator + */ +#ifndef MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_ +#define MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_ + +#include +#include +#include +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../operator_common.h" +#include "../elemwise_op_common.h" +#include "np_broadcast_reduce_op.h" + +namespace mxnet { +namespace op { + +struct NumpyBincountParam : public dmlc::Parameter { + int minlength; + bool has_weights; + DMLC_DECLARE_PARAMETER(NumpyBincountParam) { + DMLC_DECLARE_FIELD(minlength) + .set_default(0) + .describe("A minimum number of bins for the output array" + "If minlength is specified, there will be at least this" + "number of bins in the output array"); + DMLC_DECLARE_FIELD(has_weights) + .set_default(false) + .describe("Determine whether Bincount has weights."); + } +}; + +inline bool NumpyBincountType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyBincountParam& param = nnvm::get(attrs.parsed); + if (!param.has_weights) { + return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs) && in_attrs->at(0) != -1; + } else { + CHECK_EQ(out_attrs->size(), 1U); + CHECK_EQ(in_attrs->size(), 2U); + TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1)); + TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0)); + return out_attrs->at(0) != -1 && in_attrs->at(0) != -1; + } +} + +inline bool NumpyBincountStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector *in_attrs, + std::vector *out_attrs) { + const NumpyBincountParam& param = nnvm::get(attrs.parsed); + if (param.has_weights) { + CHECK_EQ(in_attrs->size(), 2U); + } else { + CHECK_EQ(in_attrs->size(), 1U); + } + CHECK_EQ(out_attrs->size(), 1U); + for (int &attr : *in_attrs) { + CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported"; + } + for (int &attr : *out_attrs) { + attr = kDefaultStorage; + } + *dispatch_mode = DispatchMode::kFComputeEx; + return true; +} + +template +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &weights, + const NDArray &out, + const size_t &data_n, + const int &minlength); + +template +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &out, + const size_t &data_n, + const int &minlength); + +template +void NumpyBincountForward(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + CHECK_GE(inputs.size(), 1U); + CHECK_EQ(outputs.size(), 1U); + CHECK(req[0] == kWriteTo); + const NumpyBincountParam& param = nnvm::get(attrs.parsed); + const bool has_weights = param.has_weights; + const int minlength = param.minlength; + const NDArray &data = inputs[0]; + const NDArray &out = outputs[0]; + CHECK_LE(data.shape().ndim(), 1U) << "Input only accept 1d array"; + CHECK(!common::is_float(data.dtype())) <<"Input data should be int type"; + size_t N = data.shape().Size(); + if (N == 0) { + mshadow::Stream *stream = ctx.get_stream(); + mxnet::TShape s(1, minlength); + const_cast(out).Init(s); + MSHADOW_TYPE_SWITCH(out.dtype(), OType, { + mxnet_op::Kernel::Launch( + stream, minlength, out.data().dptr()); + }); + } else { + if (has_weights) { + CHECK_EQ(inputs.size(), 2U); + const NDArray &weights = inputs[1]; + CHECK_EQ(data.shape(), weights.shape()) << "weights should has same size as input"; + NumpyBincountForwardImpl(ctx, data, weights, out, N, minlength); + } else { + NumpyBincountForwardImpl(ctx, data, out, N, minlength); + } + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_ diff --git a/src/operator/numpy/np_bincount_op.cc b/src/operator/numpy/np_bincount_op.cc new file mode 100644 index 000000000000..6256db176977 --- /dev/null +++ b/src/operator/numpy/np_bincount_op.cc @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bicount_op.cc + * \brief numpy compatible bincount operator CPU registration + */ + +#include "./np_bincount_op-inl.h" + +namespace mxnet { +namespace op { + +void BinNumberCount(const NDArray& data, const int& minlength, + const NDArray& out, const size_t& N) { + int bin = minlength; + MSHADOW_TYPE_SWITCH(data.dtype(), DType, { + DType* data_ptr = data.data().dptr(); + for (size_t i = 0; i < N; i++) { + CHECK_GE(data_ptr[i], 0) << "input should be nonnegative number"; + if (data_ptr[i] + 1 > bin) { + bin = data_ptr[i] + 1; + } + } + }); // bin number = max(max(data) + 1, minlength) + mxnet::TShape s(1, bin); + const_cast(out).Init(s); // set the output shape forcefully +} + +template +void BincountCpuWeights(const DType* data, const OType* weights, + OType* out, const size_t& data_n) { + for (size_t i = 0; i < data_n; i++) { + int target = data[i]; + out[target] += weights[i]; + } +} + +template +void BincountCpu(const DType* data, OType* out, const size_t& data_n) { + for (size_t i = 0; i < data_n; i++) { + int target = data[i]; + out[target] += 1; + } +} + +template<> +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &weights, + const NDArray &out, + const size_t &data_n, + const int &minlength) { + using namespace mxnet_op; + BinNumberCount(data, minlength, out, data_n); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(data.dtype(), DType, { + MSHADOW_TYPE_SWITCH(weights.dtype(), OType, { + size_t out_size = out.shape()[0]; + Kernel::Launch(s, out_size, out.data().dptr()); + BincountCpuWeights(data.data().dptr(), weights.data().dptr(), + out.data().dptr(), data_n); + }); + }); +} + +template<> +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &out, + const size_t &data_n, + const int &minlength) { + using namespace mxnet_op; + BinNumberCount(data, minlength, out, data_n); + mshadow::Stream *s = ctx.get_stream(); + MSHADOW_TYPE_SWITCH(data.dtype(), DType, { + MSHADOW_TYPE_SWITCH(out.dtype(), OType, { + size_t out_size = out.shape()[0]; + Kernel::Launch(s, out_size, out.data().dptr()); + BincountCpu(data.data().dptr(), out.data().dptr(), data_n); + }); + }); +} + +DMLC_REGISTER_PARAMETER(NumpyBincountParam); + +NNVM_REGISTER_OP(_npi_bincount) +.set_attr_parser(ParamParser) +.set_num_inputs([](const NodeAttrs& attrs) { + const NumpyBincountParam& params = + nnvm::get(attrs.parsed); + return params.has_weights? 2 : 1; + }) +.set_num_outputs(1) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + const NumpyBincountParam& params = + nnvm::get(attrs.parsed); + return params.has_weights ? + std::vector{"data", "weights"} : + std::vector{"data"}; + }) +.set_attr("FResourceRequest", +[](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; +}) +.set_attr("FInferType", NumpyBincountType) +.set_attr("FInferStorageType", NumpyBincountStorageType) +.set_attr("FComputeEx", NumpyBincountForward) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "Data") +.add_argument("weights", "NDArray-or-Symbol", "Weights") +.add_arguments(NumpyBincountParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/numpy/np_bincount_op.cu b/src/operator/numpy/np_bincount_op.cu new file mode 100644 index 000000000000..ed1f90f00c16 --- /dev/null +++ b/src/operator/numpy/np_bincount_op.cu @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2019 by Contributors + * \file np_bicount_op.cu + * \brief numpy compatible bincount operator GPU registration + */ + +#include "./np_bincount_op-inl.h" +#include +#include +#include "../tensor/util/tensor_util-inl.cuh" +#include "../tensor/util/tensor_util-inl.h" + +namespace mxnet { +namespace op { + +struct BincountFusedKernel { + template + static MSHADOW_XINLINE void Map(int i, const DType* data, OType* out) { + int idx = data[i]; + atomicAdd(&out[idx], 1); + } + + template + static MSHADOW_XINLINE void Map(int i, const DType* data, const OType* weights, + OType* out) { + int idx = data[i]; + atomicAdd(&out[idx], weights[i]); + } +}; + +struct is_valid_check { + template + MSHADOW_XINLINE static void Map(int i, char* invalid_ptr, const DType* data) { + if (data[i] < 0) *invalid_ptr = 1; + } +}; + +template +bool CheckInvalidInput(mshadow::Stream *s, const DType *data, const size_t& data_size, + char* is_valid_ptr) { + using namespace mxnet_op; + int32_t is_valid = 0; + Kernel::Launch(s, 1, is_valid_ptr); + Kernel::Launch(s, data_size, is_valid_ptr, data); + CUDA_CALL(cudaMemcpyAsync(&is_valid, is_valid_ptr, sizeof(char), + cudaMemcpyDeviceToHost, mshadow::Stream::GetStream(s))); + CUDA_CALL(cudaStreamSynchronize(mshadow::Stream::GetStream(s))); + return is_valid == 0; +} + +template<> +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &weights, + const NDArray &out, + const size_t &data_n, + const int &minlength) { + using namespace mxnet_op; + mshadow::Stream *s = ctx.get_stream(); + + MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, { + DType* h_ptr; + DType* d_ptr; + int bin = minlength; + d_ptr = data.data().dptr(); + Tensor workspace = ctx.requested[0] + .get_space_typed(Shape1(1), s); + char* is_valid_ptr = reinterpret_cast(workspace.dptr_); + bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr); + CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input + + h_ptr = reinterpret_cast(malloc(data_n*sizeof(DType))); + CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost, + mshadow::Stream::GetStream(s))); + CUDA_CALL(cudaStreamSynchronize(mshadow::Stream::GetStream(s))); + for (size_t i = 0; i < data_n; i++) { + if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1; + } + free(h_ptr); + mxnet::TShape s(1, bin); + const_cast(out).Init(s); // set the output shape forcefully + }); + + MSHADOW_TYPE_SWITCH(data.dtype(), DType, { + MSHADOW_TYPE_SWITCH(weights.dtype(), OType, { + size_t out_size = out.shape().Size(); + Kernel::Launch(s, out_size, out.data().dptr()); + Kernel::Launch( + s, data_n, data.data().dptr(), weights.data().dptr(), + out.data().dptr()); + }); + }); +} + +template<> +void NumpyBincountForwardImpl(const OpContext &ctx, + const NDArray &data, + const NDArray &out, + const size_t &data_n, + const int &minlength) { + using namespace mxnet_op; + mshadow::Stream *s = ctx.get_stream(); + + MXNET_NO_FLOAT16_TYPE_SWITCH(data.dtype(), DType, { + DType* h_ptr; + DType* d_ptr; + int bin = minlength; + d_ptr = data.data().dptr(); + Tensor workspace = ctx.requested[0] + .get_space_typed(Shape1(1), s); + char* is_valid_ptr = reinterpret_cast(workspace.dptr_); + bool is_valid = CheckInvalidInput(s, d_ptr, data_n, is_valid_ptr); + CHECK(is_valid) << "Input should be nonnegative number"; // check invalid input + + h_ptr = reinterpret_cast(malloc(data_n*sizeof(DType))); + CUDA_CALL(cudaMemcpyAsync(h_ptr, d_ptr, data_n*sizeof(DType), cudaMemcpyDeviceToHost, + mshadow::Stream::GetStream(s))); + CUDA_CALL(cudaStreamSynchronize(mshadow::Stream::GetStream(s))); + for (size_t i = 0; i < data_n; i++) { + if (h_ptr[i] + 1 > bin) bin = h_ptr[i] + 1; + } + free(h_ptr); + mxnet::TShape s(1, bin); + const_cast(out).Init(s); // set the output shape forcefully + }); + + MSHADOW_TYPE_SWITCH(data.dtype(), DType, { + MSHADOW_TYPE_SWITCH(out.dtype(), OType, { + size_t out_size = out.shape().Size(); + Kernel::Launch(s, out_size, out.data().dptr()); + Kernel::Launch( + s, data_n, data.data().dptr(), out.data().dptr()); + }); + }); +} + +NNVM_REGISTER_OP(_npi_bincount) +.set_attr("FComputeEx", NumpyBincountForward); + +} // namespace op +} // namespace mxnet diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 4c4e8b90eca9..a147dabc4949 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -67,6 +67,23 @@ def _add_workload_unravel_index(): OpArgMngr.add_workload('unravel_index', np.array([],dtype=_np.int64), (10, 3, 5)) OpArgMngr.add_workload('unravel_index', np.array([3], dtype=_np.int32), (2,2)) +def _add_workload_bincount(): + y = np.arange(4).astype(int) + y1 = np.array([1, 5, 2, 4, 1], dtype=_np.int64) + y2 = np.array((), dtype=_np.int8) + w = np.array([0.2, 0.3, 0.5, 0.1]) + w1 = np.array([0.2, 0.3, 0.5, 0.1, 0.2]) + + OpArgMngr.add_workload('bincount', y) + OpArgMngr.add_workload('bincount', y1) + OpArgMngr.add_workload('bincount', y, w) + OpArgMngr.add_workload('bincount', y1, w1) + OpArgMngr.add_workload('bincount', y1, w1, 8) + OpArgMngr.add_workload('bincount', y, minlength=3) + OpArgMngr.add_workload('bincount', y, minlength=8) + OpArgMngr.add_workload('bincount', y2, minlength=0) + OpArgMngr.add_workload('bincount', y2, minlength=5) + def _add_workload_diag(): def get_mat(n): @@ -1372,6 +1389,7 @@ def _prepare_workloads(): _add_workload_around() _add_workload_argsort() _add_workload_append() + _add_workload_bincount() _add_workload_broadcast_arrays(array_pool) _add_workload_broadcast_to() _add_workload_clip() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 078e37fc4146..7c998a8f9857 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -5491,6 +5491,56 @@ def hybrid_forward(self, F, a): assert_almost_equal(elem_mx.asnumpy(), elem_np, rtol=rtol, atol=atol) +@with_seed() +@use_np +def test_np_bincount(): + class TestBincount(HybridBlock): + def __init__(self, minlength=0): + super(TestBincount, self).__init__() + self._minlength = minlength + + def hybrid_forward(self, F, a): + return F.np.bincount(a, None, self._minlength) + + class TestBincountWeights(HybridBlock): + def __init__(self, minlength=0): + super(TestBincountWeights, self).__init__() + self._minlength = minlength + + def hybrid_forward(self, F, a, weights): + return F.np.bincount(a, weights, self._minlength) + + dtypes = [np.int8, np.uint8, np.int32, np.int64] + weight_types = [np.int32, np.int64, np.float16, np.float32, np.float64] + shapes = [(), (5,), (10,), (15,), (20,), (30,), (50,)] + min_lengths = [0, 5, 20, 50] + has_weights = [True, False] + combinations = itertools.product([True, False], shapes, dtypes, weight_types, has_weights, min_lengths) + for hybridize, shape, dtype, weight_type, has_weight, minlength in combinations: + rtol = 1e-2 if weight_type == np.float16 else 1e-3 + atol = 1e-4 if weight_type == np.float16 else 1e-5 + if shape != (): + data = np.random.uniform(0, 10, size=shape).astype(dtype) + weights = np.random.uniform(0, 10, size=shape).astype(weight_type) if has_weight else None + else: + data = np.array(()).astype(dtype) + weights = np.array(()).astype(weight_type) if has_weight else None + weights_np = weights.asnumpy() if has_weight else None + test_bincount = TestBincountWeights(minlength) if has_weight else TestBincount(minlength) + if hybridize: + test_bincount.hybridize() + mx_out = test_bincount(data, weights) if has_weight else test_bincount(data) + np_out = _np.bincount(data.asnumpy(), weights_np, minlength) + assert mx_out.shape == np_out.shape + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + # No backward operation for operator bincount at this moment + + # Test imperative once again + mx_out = np.bincount(data, weights, minlength) + np_out = _np.bincount(data.asnumpy(), weights_np, minlength) + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=rtol, atol=atol) + + if __name__ == '__main__': import nose nose.runmodule()