diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 51f0a5fd4da9..05ba061bc66c 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -77,10 +77,15 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis) else: slices = [] - for i in range(num_slice): - st = div_points[i] - end = div_points[i + 1] - slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end)) + if batch_axis != 0: + for i in range(num_slice): + st = div_points[i] + end = div_points[i + 1] + slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end)) + else: + # Fixes issue: https://github.com/apache/incubator-mxnet/issues/19268 + slices = [data[div_points[i]:div_points[i + 1]] if i < num_slice - 1 else data[div_points[i]:size] + for i in range(num_slice)] return slices diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 42252d52be2b..7bacb4f0b317 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -29,6 +29,8 @@ import mxnet.numpy as _mx_np from common import (setup_module, with_seed, assertRaises, teardown, assert_raises_cudnn_not_satisfied, environment) +import mxnet.ndarray.sparse as mxsps +import scipy.sparse as sps import numpy as np from numpy.testing import assert_array_equal from nose.tools import raises, assert_raises @@ -3229,6 +3231,15 @@ def hybrid_forward(self, F, x): mx.test_utils.assert_almost_equal(grad1, grad2) + +def test_split_and_load(): + ctx_list = (mx.cpu(0), mx.cpu(0)) + csr_arr = mxsps.csr_matrix(sps.coo_matrix(([2.0], ([99], [999]))).tocsr(), ctx=mx.cpu(0)) + arr_list = gluon.utils.split_and_load(csr_arr, ctx_list) + assert hasattr(arr_list[0], 'indices') + assert isinstance(arr_list[0], mxsps.CSRNDArray) + + if __name__ == '__main__': import nose nose.runmodule()