Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

refactor gluon.utils.split_data() following np.array_split() #17123

Merged
merged 6 commits into from
Jan 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 11 additions & 23 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,30 +70,18 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
"uneven partitioning of data."%(
str(data.shape), num_slice, batch_axis, num_slice))

step = size // num_slice

# If size < num_slice, make fewer slices
if not even_split and size < num_slice:
step = 1
num_slice = size

if batch_axis == 0:
slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
for i in range(num_slice)]
elif even_split:
sxjscience marked this conversation as resolved.
Show resolved Hide resolved
if is_np_array():
slices = _mx_np.split(data, indices_or_sections=num_slice, axis=batch_axis)
else:
slices = ndarray.split(data, num_outputs=num_slice, axis=batch_axis)
n_each_section, extras = divmod(size, num_slice)
section_sizes = [0] + (extras * [n_each_section + 1] +
(num_slice - extras) * [n_each_section])
div_points = np.array(section_sizes).cumsum()
if is_np_array():
slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
else:
if is_np_array():
indices = [step * i for i in range(1, num_slice)]
slices = _mx_np.split(data, indices_or_sections=indices, axis=batch_axis)
else:
slices = [ndarray.slice_axis(data, batch_axis, i*step, (i+1)*step)
if i < num_slice - 1 else
ndarray.slice_axis(data, batch_axis, i*step, size)
for i in range(num_slice)]
slices = []
for i in range(num_slice):
st = div_points[i]
end = div_points[i + 1]
slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
return slices


Expand Down
32 changes: 28 additions & 4 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from mxnet.gluon import nn
from mxnet.base import py_str
from mxnet.test_utils import assert_almost_equal
from mxnet.util import is_np_array
from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
from mxnet.test_utils import use_np
import mxnet.numpy as _mx_np
from common import (setup_module, with_seed, assertRaises, teardown,
assert_raises_cudnn_not_satisfied)
import numpy as np
Expand Down Expand Up @@ -952,17 +955,39 @@ def test_deferred_init():
layer(x)



def check_split_data(x, num_slice, batch_axis, **kwargs):
res = gluon.utils.split_data(x, num_slice, batch_axis, **kwargs)
assert len(res) == num_slice
mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
x.asnumpy())
if not is_np_array():
mx.test_utils.assert_almost_equal(mx.nd.concat(*res, dim=batch_axis).asnumpy(),
x.asnumpy())
else:
mx.test_utils.assert_almost_equal(_mx_np.concatenate(res, axis=batch_axis).asnumpy(),
x.asnumpy())
np_res = np.array_split(x.asnumpy(), num_slice, axis=batch_axis)
res_asnp = [s.asnumpy() for s in res]
for r1, r2 in zip(np_res, res_asnp):
assert all(r1.reshape(-1) == r2.reshape(-1))


@with_seed()
@use_np
def test_split_data_np():
x = _mx_np.random.uniform(size=(128, 33, 64))
check_split_data(x, 8, 0)
check_split_data(x, 3, 1)
check_split_data(x, 4, 1, even_split=False)
check_split_data(x, 15, 1, even_split=False)
try:
check_split_data(x, 4, 1)
except ValueError:
return
assert False, "Should have failed"

@with_seed()
def test_split_data():
x = mx.nd.random.uniform(shape=(128, 33, 64))
zburning marked this conversation as resolved.
Show resolved Hide resolved

check_split_data(x, 8, 0)
check_split_data(x, 3, 1)
check_split_data(x, 4, 1, even_split=False)
Expand All @@ -973,7 +998,6 @@ def test_split_data():
return
assert False, "Should have failed"


@with_seed()
def test_flatten():
flatten = nn.Flatten()
Expand Down