diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 0bf6232c0e78..dfb8685c45b8 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -32,11 +32,11 @@ 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack', - 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', - 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', - 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', + 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', + 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', + 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] @set_module('mxnet.ndarray.numpy') @@ -823,7 +823,6 @@ def eye(N, M=None, k=0, dtype=_np.float32, **kwargs): def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" Return evenly spaced numbers over a specified interval. - Returns num evenly spaced samples, calculated over the interval [start, stop]. The endpoint of the interval can optionally be excluded. @@ -2354,7 +2353,7 @@ def split(ary, indices_or_sections, axis=0): ---------- ary : ndarray Array to be divided into sub-arrays. - indices_or_sections : int or 1-D array + indices_or_sections : int or 1-D python tuple, list or set. If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays along `axis`. If such a split is not possible, an error is raised. @@ -2386,10 +2385,10 @@ def split(ary, indices_or_sections, axis=0): raise ValueError('array split does not result in an equal division') section_size = int(axis_size / sections) indices = [i * section_size for i in range(sections)] - elif isinstance(indices_or_sections, tuple): + elif isinstance(indices_or_sections, (list, set, tuple)): indices = [0] + list(indices_or_sections) else: - raise ValueError('indices_or_sections must either int or tuple of ints') + raise ValueError('indices_or_sections must either int, or tuple / list / set of ints') ret = _npi.split(ary, indices, axis, False) if not isinstance(ret, list): return [ret] @@ -2397,6 +2396,83 @@ def split(ary, indices_or_sections, axis=0): # pylint: enable=redefined-outer-name +@set_module('mxnet.ndarray.numpy') +def vsplit(ary, indices_or_sections): + r""" + vsplit(ary, indices_or_sections) + + Split an array into multiple sub-arrays vertically (row-wise). + + ``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split + along the first axis regardless of the array dimension. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1 - D Python tuple, list or set. + If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays + along axis 0. If such a split is not possible, an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where + along axis 0 the array is split. For example, ``[2, 3]`` would result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along axis 0, an error will be thrown. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + See Also + -------- + split : Split an array into multiple sub-arrays of equal size. + + Notes + ------- + This function differs from the original `numpy.degrees + `_ in + the following aspects: + + - Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar, + tuple and list. + - In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0, + an error will be thrown. + + Examples + -------- + >>> x = np.arange(16.0).reshape(4, 4) + >>> x + array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [ 12., 13., 14., 15.]]) + >>> np.vsplit(x, 2) + [array([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])] + + With a higher dimensional array the split is still along the first axis. + + >>> x = np.arange(8.0).reshape(2, 2, 2) + >>> x + array([[[ 0., 1.], + [ 2., 3.]], + [[ 4., 5.], + [ 6., 7.]]]) + >>> np.vsplit(x, 2) + [array([[[0., 1.], + [2., 3.]]]), array([[[4., 5.], + [6., 7.]]])] + + """ + return split(ary, indices_or_sections, 0) + + @set_module('mxnet.ndarray.numpy') def concatenate(seq, axis=0, out=None): """Join a sequence of arrays along an existing axis. diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 76df87cb01e0..a5fc9598ae9b 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -50,11 +50,12 @@ 'sqrt', 'cbrt', 'abs', 'absolute', 'exp', 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', - 'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', - 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', - 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'arctan2', 'hypot', - 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', - 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] + 'tensordot', 'histogram', 'eye', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', + 'concatenate', 'stack', 'vstack', 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', + 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', + 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', + 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', + 'less_equal'] # Return code for dispatching indexing function call @@ -3922,7 +3923,7 @@ def split(ary, indices_or_sections, axis=0): ---------- ary : ndarray Array to be divided into sub-arrays. - indices_or_sections : int or 1-D array + indices_or_sections : int or 1-D Python tuple, list or set. If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays along `axis`. If such a split is not possible, an error is raised. @@ -3948,6 +3949,83 @@ def split(ary, indices_or_sections, axis=0): return _mx_nd_np.split(ary, indices_or_sections, axis=axis) +@set_module('mxnet.numpy') +def vsplit(ary, indices_or_sections): + r""" + vsplit(ary, indices_or_sections) + + Split an array into multiple sub-arrays vertically (row-wise). + + ``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split + along the first axis regardless of the array dimension. + + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1 - D Python tuple, list or set. + If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays + along axis 0. If such a split is not possible, an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where + along axis 0 the array is split. For example, ``[2, 3]`` would result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along axis 0, an error will be thrown. + + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + + See Also + -------- + split : Split an array into multiple sub-arrays of equal size. + + Notes + ------- + This function differs from the original `numpy.degrees + `_ in + the following aspects: + + - Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar, + tuple and list. + - In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0, + an error will be thrown. + + Examples + -------- + >>> x = np.arange(16.0).reshape(4, 4) + >>> x + array([[ 0., 1., 2., 3.], + [ 4., 5., 6., 7.], + [ 8., 9., 10., 11.], + [ 12., 13., 14., 15.]]) + >>> np.vsplit(x, 2) + [array([[0., 1., 2., 3.], + [4., 5., 6., 7.]]), array([[ 8., 9., 10., 11.], + [12., 13., 14., 15.]])] + + With a higher dimensional array the split is still along the first axis. + + >>> x = np.arange(8.0).reshape(2, 2, 2) + >>> x + array([[[ 0., 1.], + [ 2., 3.]], + [[ 4., 5.], + [ 6., 7.]]]) + >>> np.vsplit(x, 2) + [array([[[0., 1.], + [2., 3.]]]), array([[[4., 5.], + [6., 7.]]])] + + """ + return split(ary, indices_or_sections, 0) + + @set_module('mxnet.numpy') def concatenate(seq, axis=0, out=None): """Join a sequence of arrays along an existing axis. diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index cbd46f3e0b07..7ed23692e5e1 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -34,11 +34,11 @@ 'expm1', 'arcsin', 'arccos', 'arctan', 'sign', 'log', 'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot', 'histogram', 'eye', - 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'vstack', 'dstack', - 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', - 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', - 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'vsplit', 'concatenate', 'stack', 'vstack', + 'dstack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', + 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip', 'around', 'hypot', + 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take', 'ldexp', 'vdot', + 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal'] def _num_outputs(sym): @@ -2573,7 +2573,7 @@ def split(ary, indices_or_sections, axis=0): ---------- ary : ndarray Array to be divided into sub-arrays. - indices_or_sections : int or 1-D array + indices_or_sections : int or 1-D python tuple, list or set. If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays along `axis`. If such a split is not possible, an error is raised. @@ -2600,15 +2600,66 @@ def split(ary, indices_or_sections, axis=0): sections = 0 if isinstance(indices_or_sections, int): sections = indices_or_sections - elif isinstance(indices_or_sections, tuple): + elif isinstance(indices_or_sections, (list, set, tuple)): indices = [0] + list(indices_or_sections) else: - raise ValueError('indices_or_sections must either int or tuple of ints') + raise ValueError('indices_or_sections must either int or tuple / list / set of ints') ret = _npi.split(ary, indices, axis, False, sections) return ret # pylint: enable=redefined-outer-name +@set_module('mxnet.symbol.numpy') +def vsplit(ary, indices_or_sections): + r""" + vsplit(ary, indices_or_sections) + + Split an array into multiple sub-arrays vertically (row-wise). + + ``vsplit`` is equivalent to ``split`` with `axis=0` (default): the array is always split + along the first axis regardless of the array dimension. + + Parameters + ---------- + ary : _Symbol + Array to be divided into sub-arrays. + indices_or_sections : int or 1 - D Python tuple, list or set. + If `indices_or_sections` is an integer, N, the array will be divided into N equal arrays + along axis 0. If such a split is not possible, an error is raised. + + If `indices_or_sections` is a 1-D array of sorted integers, the entries indicate where + along axis 0 the array is split. For example, ``[2, 3]`` would result in + + - ary[:2] + - ary[2:3] + - ary[3:] + + If an index exceeds the dimension of the array along axis 0, an error will be thrown. + + Returns + ------- + sub-arrays : list of _Symbols + A list of sub-arrays. + + See Also + -------- + split : Split an array into multiple sub-arrays of equal size. + + Notes + ------- + This function differs from the original `numpy.degrees + `_ in + the following aspects: + + - Currently parameter ``indices_or_sections`` does not support ndarray, but supports scalar, + tuple and list + - In ``indices_or_sections``, if an index exceeds the dimension of the array along axis 0, + an error will be thrown. + + """ + return split(ary, indices_or_sections, 0) + + @set_module('mxnet.symbol.numpy') def concatenate(seq, axis=0, out=None): """Join a sequence of arrays along an existing axis. diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 39420239c919..519d2b025eeb 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -1583,6 +1583,61 @@ def get_indices(axis_size): assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) +@with_seed() +@use_np +def test_np_vsplit(): + class TestVsplit(HybridBlock): + def __init__(self, indices_or_sections): + super(TestVsplit, self).__init__() + self._indices_or_sections = indices_or_sections + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.vsplit(a, indices_or_sections=self._indices_or_sections) + + def get_indices(axis_size): + if axis_size is 0: + axis_size = random.randint(3, 6) + samples = random.randint(1, axis_size - 1) + indices = sorted(random.sample([i for i in range(1, axis_size)], samples)) + indices = tuple(indices) + return indices + + shapes = [ + (2, 1, 2, 9), + (4, 3, 3), + (4, 0, 2), # zero-size shape + (0, 3), # first dim being zero + ] + for hybridize in [True, False]: + for shape in shapes: + axis_size = shape[0] + indices = get_indices(axis_size) + sections = 7 if axis_size is 0 else axis_size + for indices_or_sections in [indices, sections]: + # test gluon + test_vsplit = TestVsplit(indices_or_sections=indices_or_sections) + if hybridize: + test_vsplit.hybridize() + a = rand_ndarray(shape).as_np_ndarray() # TODO: check type + a.attach_grad() + expected_ret = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections) + with mx.autograd.record(): + y = test_vsplit(a) + assert len(y) == len(expected_ret) + for mx_out, np_out in zip(y, expected_ret): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx.autograd.backward(y) + + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_outs = np.vsplit(a, indices_or_sections=indices_or_sections) + np_outs = _np.vsplit(a.asnumpy(), indices_or_sections=indices_or_sections) + for mx_out, np_out in zip(mx_outs, np_outs): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + @with_seed() @use_np def test_np_concat():