Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Numpy compatible vsplit; minor changes to split (#15983)
Browse files Browse the repository at this point in the history
  • Loading branch information
zoeygxy authored and haojin2 committed Oct 13, 2019
1 parent 08a4db7 commit 4b535a8
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 23 deletions.
94 changes: 85 additions & 9 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack',
'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -823,7 +823,6 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.
Expand Down Expand Up @@ -2354,7 +2353,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand Down Expand Up @@ -2386,17 +2385,94 @@ def split(ary, indices_or_sections, axis=0):
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
raise ValueError('indices_or_sections must either int, or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)
Split an array into multiple sub-arrays vertically (row-wise).
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
See Also
--------
split : Split an array into multiple sub-arrays of equal size.
Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list.
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.
Examples
--------
>>> x = np.arange(16.0).reshape(4, 4)
>>> x
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.]])
>>> np.vsplit(x, 2)
[array([[0., 1., 2., 3.],
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])]
With a higher dimensional array the split is still along the first axis.
>>> x = np.arange(8.0).reshape(2, 2, 2)
>>> x
array([[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]])
>>> np.vsplit(x, 2)
[array([[[0., 1.],
[2., 3.]]]), array([[[4., 5.],
[6., 7.]]])]
"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
90 changes: 84 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit',
'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal']


# Return code for dispatching indexing function call
Expand Down Expand Up @@ -3922,7 +3923,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand All @@ -3948,6 +3949,83 @@ def split(ary, indices_or_sections, axis=0):
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)
Split an array into multiple sub-arrays vertically (row-wise).
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
See Also
--------
split : Split an array into multiple sub-arrays of equal size.
Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list.
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.
Examples
--------
>>> x = np.arange(16.0).reshape(4, 4)
>>> x
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.]])
>>> np.vsplit(x, 2)
[array([[0., 1., 2., 3.],
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])]
With a higher dimensional array the split is still along the first axis.
>>> x = np.arange(8.0).reshape(2, 2, 2)
>>> x
array([[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]])
>>> np.vsplit(x, 2)
[array([[[0., 1.],
[2., 3.]]]), array([[[4., 5.],
[6., 7.]]])]
"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
67 changes: 59 additions & 8 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack',
'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']


def _num_outputs(sym):
Expand Down Expand Up @@ -2573,7 +2573,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand All @@ -2600,15 +2600,66 @@ def split(ary, indices_or_sections, axis=0):
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
raise ValueError('indices_or_sections must either int or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
return ret
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)
Split an array into multiple sub-arrays vertically (row-wise).
``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.
Parameters
----------
ary : _Symbol
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along axis 0, an error will be thrown.
Returns
-------
sub-arrays : list of _Symbols
A list of sub-arrays.
See Also
--------
split : Split an array into multiple sub-arrays of equal size.
Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:
- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.
"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,61 @@ def get_indices(axis_size):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_vsplit():
class TestVsplit(HybridBlock):
def __init__(self, indices_or_sections):
super(TestVsplit, self).__init__()
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.vsplit(a, indices_or_sections=self._indices_or_sections)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
return indices

shapes = [
(2, 1, 2, 9),
(4, 3, 3),
(4, 0, 2), # zero-size shape
(0, 3), # first dim being zero
]
for hybridize in [True, False]:
for shape in shapes:
axis_size = shape[0]
indices = get_indices(axis_size)
sections = 7 if axis_size is 0 else axis_size
for indices_or_sections in [indices, sections]:
# test gluon
test_vsplit = TestVsplit(indices_or_sections=indices_or_sections)
if hybridize:
test_vsplit.hybridize()
a = rand_ndarray(shape).as_np_ndarray() # TODO: check type
a.attach_grad()
expected_ret = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
with mx.autograd.record():
y = test_vsplit(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx.autograd.backward(y)

assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

# test imperative
mx_outs = np.vsplit(a, indices_or_sections=indices_or_sections)
np_outs = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_concat():
Expand Down

0 comments on commit 4b535a8

Please sign in to comment.