Skip to content

Commit

Permalink
Numpy-compatible split (apache#15049)
Browse files Browse the repository at this point in the history
* numpy split

* numpy split

* unit test

* unit test
  • Loading branch information
haojin2 authored and reminisce committed Aug 1, 2019
1 parent 0fd2f8d commit 6bba864
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 10 deletions.
57 changes: 56 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange', 'argmax',
'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'swapaxes', 'expand_dims']
'clip', 'split', 'swapaxes', 'expand_dims']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -538,3 +538,58 @@ def expand_dims(a, axis):
the input array.
"""
return _npi.expand_dims(a, axis)


@set_module('mxnet.ndarray.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division.
"""
indices = []
axis_size = ary.shape[axis]
if isinstance(indices_or_sections, int):
sections = indices_or_sections
if axis_size % sections:
raise ValueError('array split does not result in an equal division')
section_size = int(axis_size / sections)
indices = [i * section_size for i in range(sections)]
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
raise NotImplementedError('single output from split is not supported yet...')
return ret
41 changes: 40 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@

__all__ = ['ndarray', 'empty', 'array', 'zeros', 'ones', 'maximum', 'minimum', 'stack', 'arange',
'argmax', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'concatenate',
'clip', 'swapaxes', 'expand_dims']
'clip', 'split', 'swapaxes', 'expand_dims']


# This function is copied from ndarray.py since pylint
Expand Down Expand Up @@ -1718,3 +1718,42 @@ def expand_dims(a, axis):
the input array.
"""
return _npi.expand_dims(a, axis)


@set_module('mxnet.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division."""
return _mx_nd_np.split(ary, indices_or_sections, axis=axis)
50 changes: 49 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from . import _internal as _npi

__all__ = ['zeros', 'ones', 'maximum', 'minimum', 'stack', 'concatenate', 'arange', 'argmax',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'swapaxes',
'clip', 'add', 'subtract', 'multiply', 'divide', 'mod', 'power', 'split', 'swapaxes',
'expand_dims']


Expand Down Expand Up @@ -1227,4 +1227,52 @@ def expand_dims(a, axis):
return _npi.expand_dims(a, axis)


@set_module('mxnet.symbol.numpy')
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Parameters
----------
ary : ndarray
Array to be divided into sub-arrays.
indices_or_sections : int or 1-D array
If `indices_or_sections` is an integer, N, the array will be divided
into N equal arrays along `axis`. If such a split is not possible,
an error is raised.
If `indices_or_sections` is a 1-D array of sorted integers, the entries
indicate where along `axis` the array is split. For example,
``[2, 3]`` would, for ``axis=0``, result in
- ary[:2]
- ary[2:3]
- ary[3:]
If an index exceeds the dimension of the array along `axis`,
an empty sub-array is returned correspondingly.
axis : int, optional
The axis along which to split, default is 0.
Returns
-------
sub-arrays : list of ndarrays
A list of sub-arrays.
Raises
------
ValueError
If `indices_or_sections` is given as an integer, but
a split does not result in equal division."""
indices = []
sections = 0
if isinstance(indices_or_sections, int):
sections = indices_or_sections
elif isinstance(indices_or_sections, tuple):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
ret = _npi.split(ary, indices, axis, False, sections)
return ret


_set_np_symbol_class(_Symbol)
12 changes: 8 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2652,10 +2652,14 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < num_outputs; ++i) {
int start = indices[i];
int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
if (ishape[real_axis] == 0U) {
end = start;
} else {
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
}
dshape[real_axis] = (end - start);
if (param.squeeze_axis) {
CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start;
Expand Down
1 change: 1 addition & 0 deletions src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1124,6 +1124,7 @@ Example::
.add_arguments(DepthToSpaceParam::__FIELDS__());

NNVM_REGISTER_OP(_split_v2)
.add_alias("_npi_split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
Example::
Expand Down
56 changes: 53 additions & 3 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,6 @@ def __init__(self, func):
def hybrid_forward(self, F, a, *args, **kwargs):
return getattr(F.np, self._func)(a)

print(func)
np_func = getattr(_np, func)
mx_func = TestUnary(func)
np_test_data = _np.random.uniform(low, high, shape).astype(_np.float32)
Expand All @@ -350,8 +349,6 @@ def hybrid_forward(self, F, a, *args, **kwargs):

if ref_grad:
y.backward()
print(mx_test_data.grad.asnumpy())
print(ref_grad(np_test_data))
assert_almost_equal(mx_test_data.grad.asnumpy(), ref_grad(np_test_data), rtol=1e-5, atol=1e-6, equal_nan=True)

funcs = {
Expand Down Expand Up @@ -767,6 +764,59 @@ def hybrid_forward(self, F, x):
assert same(ret_mx.asnumpy(), ret_np)


@with_seed()
@npx.use_np_shape
def test_np_split():
class TestSplit(HybridBlock):
def __init__(self, indices_or_sections, axis=None):
super(TestSplit, self).__init__()
self._axis = axis
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.split(a, indices_or_sections=self._indices_or_sections,
axis=self._axis)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
return indices

dim = random.randint(0, 3)
shape = [0] + [random.randint(2, 4) for i in range(dim)]
for hybridize in [True, False]:
for axis in range(len(shape)):
indices = get_indices(shape[axis])
sections = 7 if shape[axis] is 0 else shape[axis]
for indices_or_sections in [indices, sections]:
# test gluon
test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections)
if hybridize:
test_split.hybridize()

a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
a.attach_grad()
expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
with mx.autograd.record():
y = test_split(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx.autograd.backward(y)

assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

# test imperative
mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 6bba864

Please sign in to comment.