Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Numpy compatible vsplit; minor change to numpy split #15983

Merged
merged 2 commits into from
Oct 13, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 85 additions & 9 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2',
'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack',
'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -823,7 +823,6 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs):
def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments
r"""
Return evenly spaced numbers over a specified interval.

Returns num evenly spaced samples, calculated over the interval [start, stop].
The endpoint of the interval can optionally be excluded.

Expand Down Expand Up @@ -2354,7 +2353,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand Down Expand Up @@ -2386,17 +2385,94 @@ def split(ary, indices_or_sections, axis=0):
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
raise ValueError('indices_or_sections must either int, or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
return [ret]
return ret
# pylint: enable=redefined-outer-name


@set_module('mxnet.ndarray.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.

Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.

If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in

- ary[:2]
- ary[2:3]
- ary[3:]

If an index exceeds the dimension of the array along axis 0, an error will be thrown.

Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.

See Also
--------
split : Split an array into multiple sub-arrays of equal size.

Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:

- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list.
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.

Examples
--------
>>> x = np.arange(16.0).reshape(4, 4)
>>> x
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.]])
>>> np.vsplit(x, 2)
[array([[0., 1., 2., 3.],
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])]

With a higher dimensional array the split is still along the first axis.

>>> x = np.arange(8.0).reshape(2, 2, 2)
>>> x
array([[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]])
>>> np.vsplit(x, 2)
[array([[[0., 1.],
[2., 3.]]]), array([[[4., 5.],
[6., 7.]]])]

"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.ndarray.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
90 changes: 84 additions & 6 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,11 +50,12 @@
'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log',
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var',
'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit',
'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip',
'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal',
'less_equal']


# Return code for dispatching indexing function call
Expand Down Expand Up @@ -3922,7 +3923,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand All @@ -3948,6 +3949,83 @@ def split(ary, indices_or_sections, axis=0):
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)


@set_module('mxnet.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.

Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.

If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in

- ary[:2]
- ary[2:3]
- ary[3:]

If an index exceeds the dimension of the array along axis 0, an error will be thrown.

Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.

See Also
--------
split : Split an array into multiple sub-arrays of equal size.

Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:

- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list.
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.

Examples
--------
>>> x = np.arange(16.0).reshape(4, 4)
>>> x
array([[ 0., 1., 2., 3.],
[ 4., 5., 6., 7.],
[ 8., 9., 10., 11.],
[ 12., 13., 14., 15.]])
>>> np.vsplit(x, 2)
[array([[0., 1., 2., 3.],
[4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.],
[12., 13., 14., 15.]])]

With a higher dimensional array the split is still along the first axis.

>>> x = np.arange(8.0).reshape(2, 2, 2)
>>> x
array([[[ 0., 1.],
[ 2., 3.]],
[[ 4., 5.],
[ 6., 7.]]])
>>> np.vsplit(x, 2)
[array([[[0., 1.],
[2., 3.]]]), array([[[4., 5.],
[6., 7.]]])]

"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
67 changes: 59 additions & 8 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p',
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack',
'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign',
'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad',
'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack',
'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices',
'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot',
'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot',
'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal']


def _num_outputs(sym):
Expand Down Expand Up @@ -2573,7 +2573,7 @@ def split(ary, indices_or_sections, axis=0):
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
indices_or_sections : int or 1-D python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
Expand All @@ -2600,15 +2600,66 @@ def split(ary, indices_or_sections, axis=0):
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
elif isinstance(indices_or_sections, (list, set, tuple)):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
raise ValueError('indices_or_sections must either int or tuple / list / set of ints')
ret = _npi.split(ary, indices, axis, False, sections)
return ret
# pylint: enable=redefined-outer-name


@set_module('mxnet.symbol.numpy')
def vsplit(ary, indices_or_sections):
r"""
vsplit(ary, indices_or_sections)

Split an array into multiple sub-arrays vertically (row-wise).

``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split
along the first axis regardless of the array dimension.

Parameters
----------
ary : _Symbol
Array to be divided into sub-arrays.
indices_or_sections : int or 1 - D Python tuple, list or set.
If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays
along axis 0. If such a split is not possible, an error is raised.

If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where
along axis 0 the array is split. For example, ``[2, 3]`` would result in

- ary[:2]
- ary[2:3]
- ary[3:]

If an index exceeds the dimension of the array along axis 0, an error will be thrown.

Returns
-------
sub-arrays : list of _Symbols
A list of sub-arrays.

See Also
--------
split : Split an array into multiple sub-arrays of equal size.

Notes
-------
This function differs from the original `numpy.degrees
<https://docs.scipy.org/doc/numpy/reference/generated/numpy.degrees.html>`_ in
the following aspects:

- Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar,
tuple and list
- In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0,
an error will be thrown.

"""
return split(ary, indices_or_sections, 0)


@set_module('mxnet.symbol.numpy')
def concatenate(seq, axis=0, out=None):
"""Join a sequence of arrays along an existing axis.
Expand Down
55 changes: 55 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,6 +1583,61 @@ def get_indices(axis_size):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_vsplit():
class TestVsplit(HybridBlock):
def __init__(self, indices_or_sections):
super(TestVsplit, self).__init__()
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.vsplit(a, indices_or_sections=self._indices_or_sections)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
return indices

shapes = [
(2, 1, 2, 9),
(4, 3, 3),
(4, 0, 2), # zero-size shape
(0, 3), # first dim being zero
]
for hybridize in [True, False]:
for shape in shapes:
axis_size = shape[0]
indices = get_indices(axis_size)
sections = 7 if axis_size is 0 else axis_size
for indices_or_sections in [indices, sections]:
# test gluon
test_vsplit = TestVsplit(indices_or_sections=indices_or_sections)
if hybridize:
test_vsplit.hybridize()
a = rand_ndarray(shape).as_np_ndarray() # TODO: check type
a.attach_grad()
expected_ret = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
with mx.autograd.record():
y = test_vsplit(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx.autograd.backward(y)

assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

# test imperative
mx_outs = np.vsplit(a, indices_or_sections=indices_or_sections)
np_outs = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_concat():
Expand Down