Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

split_and_load can now handle num_ctx > num_data. Github Issue #13909 #14607

Merged
merged 1 commit into from
Apr 8, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/mxnet/gluon/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
Return value is a list even if `num_slice` is 1.
"""
size = data.shape[batch_axis]
if size < num_slice:
raise ValueError(
"Too many slices for data with shape %s. Arguments are " \
"num_slice=%d and batch_axis=%d."%(str(data.shape), num_slice, batch_axis))
if even_split and size % num_slice != 0:
raise ValueError(
"data with shape %s cannot be evenly split into %d slices along axis %d. " \
Expand All @@ -75,6 +71,12 @@ def split_data(data, num_slice, batch_axis=0, even_split=True):
str(data.shape), num_slice, batch_axis, num_slice))

step = size // num_slice

# If size < num_slice, make fewer slices
if not even_split and size < num_slice:
step = 1
num_slice = size

if batch_axis == 0:
slices = [data[i*step:(i+1)*step] if i < num_slice - 1 else data[i*step:size]
for i in range(num_slice)]
Expand Down