Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy-compatible histogram
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Sep 25, 2019
1 parent ab2214b commit 7fd7277
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 3 deletions.
49 changes: 48 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
'arctan2', 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs',
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -612,6 +612,53 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint: disable=too-many-arguments
"""
Compute the histogram of a set of data.
Parameters
----------
a : NDArray
Input data. The histogram is computed over the flattened array.
bins : int or NDArray
If `bins` is an int, it defines the number of equal-width
bins in the given range (10, by default). If `bins` is a
sequence, it defines a monotonically increasing array of bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
.. versionadded:: 1.11.0
If `bins` is a string, it defines the method used to calculate the
optimal bin width, as defined by `histogram_bin_edges`.
range : (float, float)
The lower and upper range of the bins. Required when `bins` is an integer.
Values outside the range are ignored. The first element of the range must
be less than or equal to the second.
normed : bool, optional
Not supported yet, coming soon.
weights : array_like, optional
Not supported yet, coming soon.
density : bool, optional
Not supported yet, coming soon.
"""
if normed is True:
raise NotImplementedError("normed is not supported yet...")
if weights is not None:
raise NotImplementedError("weights is not supported yet...")
if density is True:
raise NotImplementedError("density is not supported yet...")
if isinstance(bins, numeric_types):
if range is None:
raise NotImplementedError("automatic range is not supported yet...")
return _npi.histogram(a, bin_cnt=bins, range=range)
if isinstance(bins, (list, tuple)):
raise NotImplementedError("array_like bins is not supported yet...")
if isinstance(bins, str):
raise NotImplementedError("string bins is not supported yet...")
if isinstance(bins, NDArray):
return _npi.histogram(a, bins=bins)
raise ValueError("np.histogram fails with", locals())


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
34 changes: 33 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'tensordot', 'histogram', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique']


# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
_NDARRAY_BASIC_INDEXING = 0
Expand Down Expand Up @@ -3444,6 +3445,37 @@ def tensordot(a, b, axes=2):
return _mx_nd_np.tensordot(a, b, axes)


@set_module('mxnet.numpy')
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint-disable=too-many-arguments
"""
Compute the histogram of a set of data.
Parameters
----------
a : NDArray
Input data. The histogram is computed over the flattened array.
bins : int or NDArray
If `bins` is an int, it defines the number of equal-width
bins in the given range (10, by default). If `bins` is a
sequence, it defines a monotonically increasing array of bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
.. versionadded:: 1.11.0
If `bins` is a string, it defines the method used to calculate the
optimal bin width, as defined by `histogram_bin_edges`.
range : (float, float)
The lower and upper range of the bins. Required when `bins` is an integer.
Values outside the range are ignored. The first element of the range must
be less than or equal to the second.
normed : bool, optional
Not supported yet, coming soon.
weights : array_like, optional
Not supported yet, coming soon.
density : bool, optional
Not supported yet, coming soon.
"""
return _mx_nd_np.histogram(a, bins=bins, range=range, normed=normed, weights=weights, density=density)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
49 changes: 48 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'log10', 'sqrt', 'cbrt', 'abs', 'absolute', 'exp',
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'mean',
'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
Expand Down Expand Up @@ -1134,6 +1134,53 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.symbol.numpy')
def histogram(a, bins=10, range=None, normed=None, weights=None, density=None): # pylint: disable= too-many-arguments
"""
Compute the histogram of a set of data.
Parameters
----------
a : Symbol
Input data. The histogram is computed over the flattened array.
bins : int or Symbol
If `bins` is an int, it defines the number of equal-width
bins in the given range (10, by default). If `bins` is a
sequence, it defines a monotonically increasing array of bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
.. versionadded:: 1.11.0
If `bins` is a string, it defines the method used to calculate the
optimal bin width, as defined by `histogram_bin_edges`.
range : (float, float)
The lower and upper range of the bins. Required when `bins` is an integer.
Values outside the range are ignored. The first element of the range must
be less than or equal to the second.
normed : bool, optional
Not supported yet, coming soon.
weights : array_like, optional
Not supported yet, coming soon.
density : bool, optional
Not supported yet, coming soon.
"""
if normed is True:
raise NotImplementedError("normed is not supported yet...")
if weights is not None:
raise NotImplementedError("weights is not supported yet...")
if density is True:
raise NotImplementedError("density is not supported yet...")
if isinstance(bins, numeric_types):
if range is None:
raise NotImplementedError("automatic range is not avaialble yet...")
return _npi.histogram(a, bin_cnt=bins, range=range)
if isinstance(bins, (list, tuple)):
raise NotImplementedError("array_like bins is not supported yet...")
if isinstance(bins, str):
raise NotImplementedError("string bins is not supported yet...")
if isinstance(bins, Symbol):
return _npi.histogram(a, bins)
raise ValueError("histogram fails with", locals())


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/histogram.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ void HistogramForwardImpl<cpu>(const OpContext& ctx,
DMLC_REGISTER_PARAMETER(HistogramParam);

NNVM_REGISTER_OP(_histogram)
.add_alias("_npi_histogram")
.describe(R"code(This operators implements the histogram function.
Example::
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1674,6 +1674,24 @@ def hybrid_forward(self, F, a):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_histogram():
shapes = [(), (3, 4), (3, 0)]

for shape in shapes:
mx_a = np.random.uniform(0.0, 10.0, size=shape)
np_a = mx_a.asnumpy()
mx_bins = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5., 6., 7., 8., 9., 10.])
np_bins = mx_bins.asnumpy()
for bins, _range in [(20, (0.0, 10.0)), (mx_bins, None)]:
print(bins, _range)
mx_cnts, mx_bins = np.histogram(mx_a, bins=bins, range=_range)
np_cnts, np_bins = _np.histogram(np_a, bins=bins if isinstance(bins, mx.base.numeric_types) else bins.asnumpy(), range=_range)
assert_almost_equal(mx_cnts.asnumpy(), np_cnts, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_bins.asnumpy(), np_bins, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_choice():
Expand Down

0 comments on commit 7fd7277

Please sign in to comment.