From 33a06a57c51b0fdbec29905416268217e530bb0e Mon Sep 17 00:00:00 2001 From: fiercex Date: Sat, 21 Sep 2019 15:17:23 +0800 Subject: [PATCH 1/3] fix dataloader --- python/mxnet/gluon/data/dataloader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 68412267da6b..66aed967f7fb 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -475,7 +475,6 @@ def __next__(self): batch = ret.get(self._timeout) if self._pin_memory: batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id)) - batch = batch[0] if len(batch) == 1 else batch self._rcvd_idx += 1 return batch except multiprocessing.context.TimeoutError: From c2c7bbbc50bcf41a18dafc968d1b86ebfd420b9e Mon Sep 17 00:00:00 2001 From: fiercex Date: Sat, 12 Oct 2019 11:30:31 +0800 Subject: [PATCH 2/3] add unittest --- tests/python/unittest/test_gluon_data.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index f87228739c36..80ffaaa10bae 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -162,6 +162,23 @@ def test_multi_worker(): for i, batch in enumerate(loader): assert (batch.asnumpy() == i).all() + +@with_seed() +def test_multi_worker_shape(): + batch_size = 1024 + shape = (batch_size+1, 11, 12) + + data = ArrayDataset(np.ones(shape)) + for thread_pool in [True, False]: + loader = gluon.data.DataLoader( + data, batch_size=batch_size, num_workers=5, last_batch='keep', thread_pool=thread_pool) + for batch in loader: + if shape[0] > batch_size: + assert batch.shape == (batch_size, shape[1], shape[2]) + shape = (shape[0] - batch_size, shape[1], shape[2]) + else: + assert batch.shape == shape + class _Dummy(Dataset): """Dummy dataset for randomized shape arrays.""" def __init__(self, random_shape): From 86c308a5505999e5447671b6c9196d45837441bc Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 5 Nov 2019 17:32:52 +0800 Subject: [PATCH 3/3] fix test --- tests/python/unittest/test_gluon_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 80ffaaa10bae..c3ae2de41722 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -165,11 +165,11 @@ def test_multi_worker(): @with_seed() def test_multi_worker_shape(): - batch_size = 1024 - shape = (batch_size+1, 11, 12) - - data = ArrayDataset(np.ones(shape)) for thread_pool in [True, False]: + batch_size = 1024 + shape = (batch_size+1, 11, 12) + + data = ArrayDataset(np.ones(shape)) loader = gluon.data.DataLoader( data, batch_size=batch_size, num_workers=5, last_batch='keep', thread_pool=thread_pool) for batch in loader: