Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
split_v2 (#13687)
Browse files Browse the repository at this point in the history
  • Loading branch information
HyperZealot authored and szha committed Jan 23, 2019
1 parent ce8b083 commit 45d1a1e
Show file tree
Hide file tree
Showing 7 changed files with 556 additions and 5 deletions.
61 changes: 60 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram",
"to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]
"split_v2", "to_dlpack_for_read", "to_dlpack_for_write", "from_dlpack"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -1133,6 +1133,14 @@ def split(self, *args, **kwargs):
"""
return op.split(self, *args, **kwargs)

def split_v2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`split_v2`.
The arguments are the same as for :py:func:`split_v2`, with
this array as data.
"""
return split_v2(self, *args, **kwargs)

def slice(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice`.
Expand Down Expand Up @@ -3901,6 +3909,12 @@ def histogram(a, bins=10, range=None):
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.
Returns
-------
NDArray
A created array.
"""

# pylint: disable= no-member, protected-access
Expand All @@ -3916,6 +3930,51 @@ def histogram(a, bins=10, range=None):
raise ValueError("bins argument should be either an integer or an NDArray")
# pylint: enable= no-member, protected-access, redefined-builtin

def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : NDArray
Array to be divided into sub-arrays.
indices_or_sections : int or tuple of ints
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
squeeze_axis: boolean, optional
Whether to squeeze the axis of sub-arrays or not, only useful when size
of the sub-arrays are 1 on the `axis`. Default is False.
Returns
-------
NDArray
A created array.
"""
indices = []
axis_size = ary.shape[axis]
if isinstance(indices_or_sections, int):
sections = indices_or_sections
if axis_size % sections:
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _internal._split_v2(ary, indices, axis, squeeze_axis)

PyCapsuleDestructor = ctypes.CFUNCTYPE(None, ctypes.c_void_p)
_c_str_dltensor = c_str('dltensor')
_c_str_used_dltensor = c_str('used_dltensor')
Expand Down
55 changes: 54 additions & 1 deletion python/mxnet/symbol/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@

__all__ = ["Symbol", "var", "Variable", "Group", "load", "load_json",
"pow", "maximum", "minimum", "hypot", "eye", "zeros", "ones", "full", "arange",
"histogram"]
"histogram", "split_v2"]


class Symbol(SymbolBase):
Expand Down Expand Up @@ -1855,6 +1855,14 @@ def split(self, *args, **kwargs):
"""
return op.split(self, *args, **kwargs)

def split_v2(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`split_v2`.
The arguments are the same as for :py:func:`split_v2`, with
this array as data.
"""
return split_v2(self, *args, **kwargs)

def slice(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`slice`.
Expand Down Expand Up @@ -2958,6 +2966,11 @@ def histogram(a, bins=10, range=None, **kwargs):
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.
Returns
-------
out : Symbol
The created Symbol
"""
if isinstance(bins, Symbol):
return _internal._histogram(data=a, bins=bins, **kwargs)
Expand All @@ -2967,4 +2980,44 @@ def histogram(a, bins=10, range=None, **kwargs):
return _internal._histogram(data=a, bin_cnt=bins, range=range, **kwargs)
raise ValueError("bins argument should be either an integer or an NDArray")

def split_v2(ary, indices_or_sections, axis=0, squeeze_axis=False):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : NDArray
Array to be divided into sub-arrays.
indices_or_sections : int or tuple of ints
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
squeeze_axis: boolean, optional
Whether to squeeze the axis of sub-arrays or not, only useful when size
of the sub-arrays are 1 on the `axis`. Default is False.
Returns
-------
out : Symbol
The created Symbol
"""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _internal._split_v2(ary, indices, axis, squeeze_axis, sections)

_set_symbol_class(Symbol)
Loading

0 comments on commit 45d1a1e

Please sign in to comment.