Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fixing breaking change introduced in #17123 when batch_axis=0 (#19267)
Browse files Browse the repository at this point in the history
Co-authored-by: Rohit Kumar Srivastava <[email protected]>
  • Loading branch information
access2rohit and Rohit Kumar Srivastava authored Oct 3, 2020
1 parent 16280ad commit a546260
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,15 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
slices = _mx_np.split(data, indices_or_sections=list(div_points[1: -1]), axis=batch_axis)
else:
slices = []
for i in range(num_slice):
st = div_points[i]
end = div_points[i + 1]
slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
if batch_axis != 0:
for i in range(num_slice):
st = div_points[i]
end = div_points[i + 1]
slices.append(ndarray.slice_axis(data, axis=batch_axis, begin=st, end=end))
else:
# Fixes issue: https://github.com/apache/incubator-mxnet/issues/19268
slices = [data[div_points[i]:div_points[i + 1]] if i < num_slice - 1 else data[div_points[i]:size]
for i in range(num_slice)]
return slices


Expand Down
11 changes: 11 additions & 0 deletions tests/python/unittest/test_gluon.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
import mxnet.numpy as _mx_np
from common import (setup_module, with_seed, assertRaises, teardown,
assert_raises_cudnn_not_satisfied, environment)
import mxnet.ndarray.sparse as mxsps
import scipy.sparse as sps
import numpy as np
from numpy.testing import assert_array_equal
from nose.tools import raises, assert_raises
Expand Down Expand Up @@ -3229,6 +3231,15 @@ def hybrid_forward(self, F, x):

mx.test_utils.assert_almost_equal(grad1, grad2)


def test_split_and_load():
ctx_list = (mx.cpu(0), mx.cpu(0))
csr_arr = mxsps.csr_matrix(sps.coo_matrix(([2.0], ([99], [999]))).tocsr(), ctx=mx.cpu(0))
arr_list = gluon.utils.split_and_load(csr_arr, ctx_list)
assert hasattr(arr_list[0], 'indices')
assert isinstance(arr_list[0], mxsps.CSRNDArray)


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit a546260

Please sign in to comment.