diff --git a/python/mxnet/gluon/utils.py b/python/mxnet/gluon/utils.py index 55edd950d223..b00cc043d493 100644 --- a/python/mxnet/gluon/utils.py +++ b/python/mxnet/gluon/utils.py @@ -63,10 +63,6 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): Return value is a list even if `num_slice` is 1. """ size = data.shape[batch_axis] - if size < num_slice: - raise ValueError( - "Too many slices for data with shape %s. Arguments are " \ - "num_slice=%d and batch_axis=%d."%(str(data.shape), num_slice, batch_axis)) if even_split and size % num_slice != 0: raise ValueError( "data with shape %s cannot be evenly split into %d slices along axis %d. " \ @@ -75,6 +71,12 @@ def split_data(data, num_slice, batch_axis=0, even_split=True): str(data.shape), num_slice, batch_axis, num_slice)) step = size // num_slice + + # If size < num_slice, make fewer slices + if not even_split and size < num_slice: + step = 1 + num_slice = size + if batch_axis == 0: slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size] for i in range(num_slice)]