Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
haojin2 committed May 30, 2019
1 parent a079144 commit 2d95b93
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 8 deletions.
6 changes: 4 additions & 2 deletions python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,6 @@ def minimum(x1, x2, out=None):


@set_module('mxnet.ndarray.numpy')
@use_np_compat
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Expand Down Expand Up @@ -223,4 +222,7 @@ def split(ary, indices_or_sections, axis=0):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _npi.split(ary, indices, axis, False)
ret = _npi.split(ary, indices, axis, False)
if not isinstance(ret, list):
raise NotImplementedError('single output from split is not supported yet...')
return ret
6 changes: 4 additions & 2 deletions python/mxnet/symbol/numpy/_symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,7 +1001,6 @@ def minimum(x1, x2, out=None):


@set_module('mxnet.symbol.numpy')
@use_np_compat
def split(ary, indices_or_sections, axis=0):
"""Split an array into multiple sub-arrays.
Expand Down Expand Up @@ -1045,6 +1044,9 @@ def split(ary, indices_or_sections, axis=0):
indices = [0] + list(indices_or_sections)
else:
raise ValueError('indices_or_sections must either int or tuple of ints')
return _npi.split(ary, indices, axis, sections)
ret = _npi.split(ary, indices, axis, False, sections)
# print(type(ret))
# ret = [ret] if not isinstance(ret, list) else ret
return ret

_set_np_symbol_class(_Symbol)
12 changes: 8 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -2637,10 +2637,14 @@ inline bool SplitOpShape(const nnvm::NodeAttrs& attrs,
for (int i = 0; i < num_outputs; ++i) {
int start = indices[i];
int end = (i < num_outputs - 1) ? indices[i + 1] : ishape[real_axis];
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
if (ishape[real_axis] == 0U) {
end = start;
} else {
CHECK(start < end)
<< "start " << start << " is not less than end " << end << "for subarray " << i;
CHECK(end <= ishape[real_axis])
<< "end " << end << " is no less than the size of the axis " << ishape[real_axis];
}
dshape[real_axis] = (end - start);
if (param.squeeze_axis) {
CHECK_EQ(end - start, 1U) << "expected axis size of 1 but got " << end - start;
Expand Down
53 changes: 53 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,59 @@ def check_minimum(x1, x2):
check_minimum(np.zeros(()), np.ones((5, 1, 4)))


@with_seed()
@np.use_np_shape
def test_np_split():
class TestSplit(HybridBlock):
def __init__(self, indices_or_sections, axis=None):
super(TestSplit, self).__init__()
self._axis = axis
self._indices_or_sections = indices_or_sections

def hybrid_forward(self, F, a, *args, **kwargs):
return F.np.split(a, indices_or_sections=self._indices_or_sections,
axis=self._axis)

def get_indices(axis_size):
if axis_size is 0:
axis_size = random.randint(3, 6)
samples = random.randint(1, axis_size - 1)
indices = sorted(random.sample([i for i in range(1, axis_size)], samples))
indices = tuple(indices)
return indices

dim = random.randint(0, 3)
shape = [0] + [random.randint(2, 4) for i in range(dim)]
for hybridize in [True, False]:
for axis in range(len(shape)):
indices = get_indices(shape[axis])
sections = 7 if shape[axis] is 0 else shape[axis]
for indices_or_sections in [indices, sections]:
# test gluon
test_split = TestSplit(axis=axis, indices_or_sections=indices_or_sections)
if hybridize:
test_split.hybridize()

a = mx.nd.random.uniform(-1.0, 1.0, shape=shape).as_np_ndarray()
a.attach_grad()
expected_ret = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
with mx.autograd.record():
y = test_split(a)
assert len(y) == len(expected_ret)
for mx_out, np_out in zip(y, expected_ret):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)

mx.autograd.backward(y)

assert_almost_equal(a.grad.asnumpy(), _np.ones(a.shape), rtol=1e-3, atol=1e-5)

# test imperative
mx_outs = np.split(a, indices_or_sections=indices_or_sections, axis=axis)
np_outs = _np.split(a.asnumpy(), indices_or_sections=indices_or_sections, axis=axis)
for mx_out, np_out in zip(mx_outs, np_outs):
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 2d95b93

Please sign in to comment.