Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
numpy-compatible split upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed Aug 13, 2019
1 parent 39bf4e0 commit 5acf58d
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 3 deletions.
52 changes: 51 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from ..ndarray import NDArray

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange']
'linspace', 'expand_dims', 'tile', 'arange', 'split']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -420,6 +420,7 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.ndarray.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
Expand Down Expand Up @@ -632,3 +633,52 @@ def tile(A, reps):
"""
return _unary_func_helper(A, _npi.tile, _np.tile, reps=reps)


@set_module('mxnet.ndarray.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division.
"""
indices = []
axis_size = ary.shape[axis]
if isinstance(indices_or_sections, int):
sections = indices_or_sections
if axis_size % sections:
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
return [ret]
return ret
36 changes: 35 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
from ..ndarray.numpy import _internal as _npi

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide',
'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange']
'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1606,6 +1606,7 @@ def tensordot(a, b, axes=2):
return _mx_nd_np.tensordot(a, b, axes)


@set_module('mxnet.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
Expand Down Expand Up @@ -1819,3 +1820,36 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None):
than `stop`.
"""
return _mx_nd_np.arange(start, stop, step, dtype, ctx)


@set_module('mxnet.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division."""
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)
45 changes: 44 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange']
'linspace', 'expand_dims', 'tile', 'arange', 'split']


def _num_outputs(sym):
Expand Down Expand Up @@ -1063,6 +1063,7 @@ def tensordot(a, b, axes=2):
return _npi.tensordot(a, b, a_axes_summed, b_axes_summed)


@set_module('mxnet.symbol.numpy')
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
Expand Down Expand Up @@ -1269,4 +1270,46 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None):
return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx)


@set_module('mxnet.symbol.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division."""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.split(ary, indices, axis, False, sections)
return ret


_set_np_symbol_class(_Symbol)
53 changes: 53 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,6 +776,59 @@ def hybrid_forward(self, F, x):
assert same(mx_out.asnumpy(), np_out)


@with_seed()
@use_np
def test_np_split():
class TestSplit(HybridBlock):
def __init__(self, indices_or_sections, axis=None):
super(TestSplit, self).__init__()
self._axis = axis
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.split(a, indices_or_sections=self._indices_or_sections,
axis=self._axis)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
return indices

dim = random.randint(0, 3)
shape = [0] + [random.randint(2, 4) for i in range(dim)]
for hybridize in [True, False]:
for axis in range(len(shape)):
indices = get_indices(shape[axis])
sections = 7 if shape[axis] is 0 else shape[axis]
for indices_or_sections in [indices, sections]:
# test gluon
test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections)
if hybridize:
test_split.hybridize()

a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
a.attach_grad()
expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
with mx.autograd.record():
y = test_split(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx.autograd.backward(y)

assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

# test imperative
mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 5acf58d

Please sign in to comment.