From 26526ba16bd8dfb0819d5e97a68268ccc65ad234 Mon Sep 17 00:00:00 2001 From: Hao Jin Date: Sat, 10 Aug 2019 04:35:54 +0000 Subject: [PATCH] numpy-compatible split upstream --- python/mxnet/ndarray/numpy/_op.py | 53 +++++++++++++++++++++++++- python/mxnet/numpy/multiarray.py | 36 ++++++++++++++++- python/mxnet/symbol/numpy/_symbol.py | 45 +++++++++++++++++++++- tests/python/unittest/test_numpy_op.py | 53 ++++++++++++++++++++++++++ 4 files changed, 184 insertions(+), 3 deletions(-) diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 30cbc6999b88..ff4dab40d296 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -27,7 +27,8 @@ from ..ndarray import NDArray __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split'] + 'split'] @set_module('mxnet.ndarray.numpy') @@ -420,6 +421,7 @@ def tensordot(a, b, axes=2): return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) +@set_module('mxnet.ndarray.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" Return evenly spaced numbers over a specified interval. @@ -632,3 +634,52 @@ def tile(A, reps): """ return _unary_func_helper(A, _npi.tile, _np.tile, reps=reps) + + +@set_module('mxnet.ndarray.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division. + """ + indices = [] + axis_size = ary.shape[axis] + if isinstance(indices_or_sections, int): + sections = indices_or_sections + if axis_size % sections: + raise ValueError('array split does not result in an equal division') + section_size = int(axis_size / sections) + indices = [i * section_size for i in range(sections)] + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + ret = _npi.split(ary, indices, axis, False) + if not isinstance(ret, list): + return [ret] + return ret diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index b8633699435d..0f05aa8da50e 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -44,7 +44,7 @@ from ..ndarray.numpy import _internal as _npi __all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', - 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange'] + 'mod', 'power', 'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split'] # This function is copied from ndarray.py since pylint @@ -1606,6 +1606,7 @@ def tensordot(a, b, axes=2): return _mx_nd_np.tensordot(a, b, axes) +@set_module('mxnet.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" Return evenly spaced numbers over a specified interval. @@ -1819,3 +1820,36 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None): than `stop`. """ return _mx_nd_np.arange(start, stop, step, dtype, ctx) + + +@set_module('mxnet.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division.""" + return _mx_nd_np.split(ary, indices_or_sections, axis=axis) diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index 587dae3deab9..bf4a6d159363 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -30,7 +30,7 @@ from . import _internal as _npi __all__ = ['zeros', 'ones', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'tensordot', - 'linspace', 'expand_dims', 'tile', 'arange'] + 'linspace', 'expand_dims', 'tile', 'arange', 'split'] def _num_outputs(sym): @@ -1063,6 +1063,7 @@ def tensordot(a, b, axes=2): return _npi.tensordot(a, b, a_axes_summed, b_axes_summed) +@set_module('mxnet.symbol.numpy') def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0, ctx=None): # pylint: disable=too-many-arguments r""" Return evenly spaced numbers over a specified interval. @@ -1269,4 +1270,46 @@ def arange(start, stop=None, step=1, dtype=None, ctx=None): return _npi.arange(start=start, stop=stop, step=step, dtype=dtype, ctx=ctx) +@set_module('mxnet.symbol.numpy') +def split(ary, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays. + Parameters + ---------- + ary : ndarray + Array to be divided into sub-arrays. + indices_or_sections : int or 1-D array + If `indices_or_sections` is an integer, N, the array will be divided + into N equal arrays along `axis`. If such a split is not possible, + an error is raised. + If `indices_or_sections` is a 1-D array of sorted integers, the entries + indicate where along `axis` the array is split. For example, + ``[2, 3]`` would, for ``axis=0``, result in + - ary[:2] + - ary[2:3] + - ary[3:] + If an index exceeds the dimension of the array along `axis`, + an empty sub-array is returned correspondingly. + axis : int, optional + The axis along which to split, default is 0. + Returns + ------- + sub-arrays : list of ndarrays + A list of sub-arrays. + Raises + ------ + ValueError + If `indices_or_sections` is given as an integer, but + a split does not result in equal division.""" + indices = [] + sections = 0 + if isinstance(indices_or_sections, int): + sections = indices_or_sections + elif isinstance(indices_or_sections, tuple): + indices = [0] + list(indices_or_sections) + else: + raise ValueError('indices_or_sections must either int or tuple of ints') + ret = _npi.split(ary, indices, axis, False, sections) + return ret + + _set_np_symbol_class(_Symbol) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 9d130e701491..aba262f88042 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -776,6 +776,59 @@ def hybrid_forward(self, F, x): assert same(mx_out.asnumpy(), np_out) +@with_seed() +@use_np +def test_np_split(): + class TestSplit(HybridBlock): + def __init__(self, indices_or_sections, axis=None): + super(TestSplit, self).__init__() + self._axis = axis + self._indices_or_sections = indices_or_sections + + def hybrid_forward(self, F, a, *args, **kwargs): + return F.np.split(a, indices_or_sections=self._indices_or_sections, + axis=self._axis) + + def get_indices(axis_size): + if axis_size is 0: + axis_size = random.randint(3, 6) + samples = random.randint(1, axis_size - 1) + indices = sorted(random.sample([i for i in range(1, axis_size)], samples)) + indices = tuple(indices) + return indices + + dim = random.randint(0, 3) + shape = [0] + [random.randint(2, 4) for i in range(dim)] + for hybridize in [True, False]: + for axis in range(len(shape)): + indices = get_indices(shape[axis]) + sections = 7 if shape[axis] is 0 else shape[axis] + for indices_or_sections in [indices, sections]: + # test gluon + test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections) + if hybridize: + test_split.hybridize() + + a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray() + a.attach_grad() + expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + with mx.autograd.record(): + y = test_split(a) + assert len(y) == len(expected_ret) + for mx_out, np_out in zip(y, expected_ret): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + mx.autograd.backward(y) + + assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5) + + # test imperative + mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis) + np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis) + for mx_out, np_out in zip(mx_outs, np_outs): + assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5) + + if __name__ == '__main__': import nose nose.runmodule()