Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Op_bincount [Numpy] #16965

Merged
merged 1 commit into from
Dec 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -5864,3 +5864,60 @@ def where(condition, x=None, y=None):
return nonzero(condition)
else:
return _npi.where(condition, x, y, out=None)


@set_module('mxnet.ndarray.numpy')
def bincount(x, weights=None, minlength=0):
    """
    Count number of occurrences of each value in array of non-negative ints.

    Parameters
    ----------
    x : ndarray
        Input array, 1 dimension, non-negative ints.
    weights : ndarray, optional
        Input weights, same shape as `x`.
    minlength : int, optional
        A minimum number of bins for the output. Must be non-negative.

    Returns
    -------
    out : ndarray
        The result of binning the input array. The length of `out` is equal to
        ``amax(x) + 1``, and at least `minlength`.

    Raises
    ------
    ValueError
        If the input is not 1-dimensional, or contains elements with negative values,
        or if `minlength` is negative.
    TypeError
        If `x` is not an NDArray, or if the type of the input is float or complex.

    Examples
    --------
    >>> np.bincount(np.arange(5))
    array([1, 1, 1, 1, 1])
    >>> np.bincount(np.array([0, 1, 1, 3, 2, 1, 7]))
    array([1, 3, 1, 1, 0, 0, 0, 1])

    >>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23])
    >>> np.bincount(x).size == np.amax(x)+1
    True

    >>> np.bincount(np.arange(5, dtype=float))
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    TypeError: array cannot be safely cast to required type

    >>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6]) # weights
    >>> x = np.array([0, 1, 1, 2, 2, 2])
    >>> np.bincount(x, weights=w)
    array([ 0.3, 0.7, 1.1])
    """
    if not isinstance(x, NDArray):
        raise TypeError("Input data should be NDarray")
    if minlength < 0:
        # Zero is accepted, so the message says "non-negative" rather than "greater than 0".
        raise ValueError("minlength must be non-negative")
    # The backend operator is told explicitly whether a weights input is present.
    if weights is None:
        return _npi.bincount(x, minlength=minlength, has_weights=False)
    return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True)
53 changes: 52 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'nonzero', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -7840,3 +7840,54 @@ def where(condition, x=None, y=None):
[ 0., 3., -1.]])
"""
return _mx_nd_np.where(condition, x, y)


@set_module('mxnet.numpy')
def bincount(x, weights=None, minlength=0):
    """
    Count how often each value occurs in an array of non-negative ints.

    This is a thin wrapper that forwards all arguments to the
    ``mxnet.ndarray.numpy`` implementation.

    Parameters
    ----------
    x : ndarray
        Input array, 1 dimension, non-negative ints.
    weights : ndarray, optional
        Input weights, same shape as `x`.
    minlength : int, optional
        A minimum number of bins for the output.

    Returns
    -------
    out : ndarray
        The result of binning the input array. The length of `out` is equal to
        ``amax(x) + 1``.

    Raises
    ------
    ValueError
        If the input is not 1-dimensional, or contains elements with negative values,
        or if `minlength` is negative.
    TypeError
        If the type of the input is float or complex.

    Examples
    --------
    >>> np.bincount(np.arange(5))
    array([1, 1, 1, 1, 1])
    >>> np.bincount(np.array([0, 1, 1, 3, 2, 1, 7]))
    array([1, 3, 1, 1, 0, 0, 0, 1])

    >>> x = np.array([0, 1, 1, 3, 2, 1, 7, 23])
    >>> np.bincount(x).size == np.amax(x)+1
    True

    >>> np.bincount(np.arange(5, dtype=float))
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    TypeError: array cannot be safely cast to required type

    >>> w = np.array([0.3, 0.5, 0.2, 0.7, 1., -0.6]) # weights
    >>> x = np.array([0, 1, 1, 2, 2, 2])
    >>> np.bincount(x, weights=w)
    array([ 0.3, 0.7, 1.1])
    """
    counts = _mx_nd_np.bincount(x, weights=weights, minlength=minlength)
    return counts
1 change: 1 addition & 0 deletions python/mxnet/numpy_dispatch_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs):
'resize',
'where',
'full_like',
'bincount'
]


Expand Down
36 changes: 35 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
'blackman', 'flip', 'around', 'hypot', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where']
'may_share_memory', 'diff', 'resize', 'nan_to_num', 'where', 'bincount']


@set_module('mxnet.symbol.numpy')
Expand Down Expand Up @@ -5420,4 +5420,38 @@ def load_json(json_str):
return _Symbol(handle)


@set_module('mxnet.symbol.numpy')
Tommliu marked this conversation as resolved.
Show resolved Hide resolved
def bincount(x, weights=None, minlength=0):
    """
    Count number of occurrences of each value in array of non-negative ints.

    Parameters
    ----------
    x : _Symbol
        Input data.
    weights : _Symbol, optional
        Input weights, same shape as `x`.
    minlength : int, optional
        A minimum number of bins for the output. Must be non-negative.

    Returns
    -------
    out : _Symbol
        The result of binning the input data. The length of `out` is equal to
        ``amax(x) + 1``, and at least `minlength`.

    Raises
    ------
    ValueError
        If the input is not 1-dimensional, or contains elements with negative values,
        or if `minlength` is negative.
    TypeError
        If the type of the input is float or complex.
    """
    if minlength < 0:
        # Zero is accepted, so the message says "non-negative" rather than "greater than 0".
        raise ValueError("minlength must be non-negative")
    # The backend operator is told explicitly whether a weights input is present.
    if weights is None:
        return _npi.bincount(x, minlength=minlength, has_weights=False)
    return _npi.bincount(x, weights=weights, minlength=minlength, has_weights=True)


_set_np_symbol_class(_Symbol)
147 changes: 147 additions & 0 deletions src/operator/numpy/np_bincount_op-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_bincount_op-inl.h
* \brief numpy compatible bincount operator
*/
#ifndef MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_
#define MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_

#include <mxnet/operator_util.h>
#include <utility>
#include <vector>
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"
#include "np_broadcast_reduce_op.h"

namespace mxnet {
namespace op {

/*!
 * \brief Parameters of the numpy-compatible bincount operator.
 */
struct NumpyBincountParam : public dmlc::Parameter<NumpyBincountParam> {
  /*! \brief Minimum number of bins in the output array. */
  int minlength;
  /*! \brief Whether a weights array is supplied as the second input. */
  bool has_weights;
  DMLC_DECLARE_PARAMETER(NumpyBincountParam) {
    // Adjacent string literals are concatenated verbatim, so every piece
    // carries its own trailing space/punctuation.
    DMLC_DECLARE_FIELD(minlength)
    .set_default(0)
    .describe("A minimum number of bins for the output array. "
              "If minlength is specified, there will be at least this "
              "number of bins in the output array.");
    DMLC_DECLARE_FIELD(has_weights)
    .set_default(false)
    .describe("Determine whether Bincount has weights.");
  }
};

/*!
 * \brief Type inference for the bincount operator.
 *
 * With weights, the output dtype and the dtype of the weights (input 1) are
 * assigned to each other. Without weights, inference is delegated to
 * ElemwiseType<1, 1>.
 *
 * \param attrs node attributes holding a parsed NumpyBincountParam
 * \param in_attrs input dtypes (data, and weights when has_weights is set)
 * \param out_attrs output dtypes (exactly one entry)
 * \return true when neither the data dtype nor the output dtype is -1
 *         (presumably the "unknown" marker)
 */
inline bool NumpyBincountType(const nnvm::NodeAttrs& attrs,
                              std::vector<int> *in_attrs,
                              std::vector<int> *out_attrs) {
  const NumpyBincountParam& param = nnvm::get<NumpyBincountParam>(attrs.parsed);
  if (param.has_weights) {
    CHECK_EQ(out_attrs->size(), 1U);
    CHECK_EQ(in_attrs->size(), 2U);
    // Tie the output dtype and the weights dtype together in both directions.
    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
    TYPE_ASSIGN_CHECK(*in_attrs, 1, out_attrs->at(0));
    const bool out_known = out_attrs->at(0) != -1;
    const bool data_known = in_attrs->at(0) != -1;
    return out_known && data_known;
  }
  // No weights: single input, single output.
  return ElemwiseType<1, 1>(attrs, in_attrs, out_attrs) && in_attrs->at(0) != -1;
}

/*!
 * \brief Storage-type inference for the bincount operator.
 *
 * Verifies the input/output counts against has_weights, requires every input
 * to use default storage, marks every output as default storage and selects
 * the FComputeEx dispatch mode.
 *
 * \param attrs node attributes holding a parsed NumpyBincountParam
 * \param dev_mask device mask (not used by this function)
 * \param dispatch_mode receives DispatchMode::kFComputeEx
 * \param in_attrs input storage types
 * \param out_attrs output storage types, overwritten with kDefaultStorage
 * \return always true (failed checks abort via CHECK)
 */
inline bool NumpyBincountStorageType(const nnvm::NodeAttrs& attrs,
                                     const int dev_mask,
                                     DispatchMode* dispatch_mode,
                                     std::vector<int> *in_attrs,
                                     std::vector<int> *out_attrs) {
  const NumpyBincountParam& param = nnvm::get<NumpyBincountParam>(attrs.parsed);
  if (!param.has_weights) {
    CHECK_EQ(in_attrs->size(), 1U);
  } else {
    CHECK_EQ(in_attrs->size(), 2U);
  }
  CHECK_EQ(out_attrs->size(), 1U);
  for (const int attr : *in_attrs) {
    CHECK_EQ(attr, kDefaultStorage) << "Only default storage is supported";
  }
  for (size_t i = 0; i < out_attrs->size(); ++i) {
    (*out_attrs)[i] = kDefaultStorage;
  }
  *dispatch_mode = DispatchMode::kFComputeEx;
  return true;
}

/*!
 * \brief Weighted bincount implementation (declaration only; the definition
 *        is not in this header - presumably provided per device type).
 * \param ctx operator context
 * \param data input values to be binned
 * \param weights weights array; the caller checks it has the same shape as data
 * \param out output array
 * \param data_n number of elements in data (the caller passes data.shape().Size())
 * \param minlength minimum number of output bins
 */
template<typename xpu>
void NumpyBincountForwardImpl(const OpContext &ctx,
const NDArray &data,
const NDArray &weights,
const NDArray &out,
const size_t &data_n,
const int &minlength);

/*!
 * \brief Unweighted bincount implementation (declaration only; the definition
 *        is not in this header - presumably provided per device type).
 * \param ctx operator context
 * \param data input values to be binned
 * \param out output array
 * \param data_n number of elements in data (the caller passes data.shape().Size())
 * \param minlength minimum number of output bins
 */
template<typename xpu>
void NumpyBincountForwardImpl(const OpContext &ctx,
const NDArray &data,
const NDArray &out,
const size_t &data_n,
const int &minlength);

/*!
 * \brief Forward entry point of the bincount operator.
 *
 * Validates the inputs, then either produces an all-zero output of length
 * minlength (empty input) or dispatches to the weighted / unweighted
 * NumpyBincountForwardImpl overload.
 *
 * \param attrs node attributes holding a parsed NumpyBincountParam
 * \param ctx operator context (supplies the device stream)
 * \param inputs data array, followed by the weights array when has_weights is set
 * \param req write request; only kWriteTo is accepted
 * \param outputs single output array
 */
template<typename xpu>
void NumpyBincountForward(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const std::vector<NDArray> &inputs,
                          const std::vector<OpReqType> &req,
                          const std::vector<NDArray> &outputs) {
  CHECK_GE(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK(req[0] == kWriteTo);
  const NumpyBincountParam& param = nnvm::get<NumpyBincountParam>(attrs.parsed);
  const bool has_weights = param.has_weights;
  const int minlength = param.minlength;
  // minlength is used below as a shape extent and a kernel launch size, so a
  // negative value must be rejected here rather than relying on the front-end.
  CHECK_GE(minlength, 0) << "minlength must be non-negative";
  const NDArray &data = inputs[0];
  const NDArray &out = outputs[0];
  CHECK_LE(data.shape().ndim(), 1U) << "Input only accept 1d array";
  CHECK(!common::is_float(data.dtype())) <<"Input data should be int type";
  size_t N = data.shape().Size();
  if (N == 0) {
    // Empty input: the output is initialised here with extent minlength and
    // filled with zeros; no implementation overload is invoked.
    mshadow::Stream<xpu> *stream = ctx.get_stream<xpu>();
    mxnet::TShape s(1, minlength);
    const_cast<NDArray &>(out).Init(s);
    MSHADOW_TYPE_SWITCH(out.dtype(), OType, {
      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
        stream, minlength, out.data().dptr<OType>());
    });
  } else {
    if (has_weights) {
      CHECK_EQ(inputs.size(), 2U);
      const NDArray &weights = inputs[1];
      CHECK_EQ(data.shape(), weights.shape()) << "weights should has same size as input";
      NumpyBincountForwardImpl<xpu>(ctx, data, weights, out, N, minlength);
    } else {
      NumpyBincountForwardImpl<xpu>(ctx, data, out, N, minlength);
    }
  }
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_BINCOUNT_OP_INL_H_
Loading