diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py
index cb6ae5a164b9..7394f7d5d3d9 100644
--- a/python/mxnet/gluon/data/dataloader.py
+++ b/python/mxnet/gluon/data/dataloader.py
@@ -507,7 +507,6 @@ def __next__(self):
             batch = ret.get(self._timeout)
             if self._pin_memory:
                 batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
-            batch = batch[0] if len(batch) == 1 else batch
             self._rcvd_idx += 1
             return batch
         except multiprocessing.context.TimeoutError:
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index f87228739c36..c3ae2de41722 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -162,6 +162,23 @@ def test_multi_worker():
     for i, batch in enumerate(loader):
         assert (batch.asnumpy() == i).all()
 
+
+@with_seed()
+def test_multi_worker_shape():
+    for thread_pool in [True, False]:
+        batch_size = 1024
+        shape = (batch_size+1, 11, 12)
+
+        data = ArrayDataset(np.ones(shape))
+        loader = gluon.data.DataLoader(
+            data, batch_size=batch_size, num_workers=5, last_batch='keep', thread_pool=thread_pool)
+        for batch in loader:
+            if shape[0] > batch_size:
+                assert batch.shape == (batch_size, shape[1], shape[2])
+                shape = (shape[0] - batch_size, shape[1], shape[2])
+            else:
+                assert batch.shape == shape
+
 class _Dummy(Dataset):
     """Dummy dataset for randomized shape arrays."""
     def __init__(self, random_shape):