From 2fc4248550c325b02a76f67b1cec32161a32dc4f Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Fri, 10 Aug 2018 13:03:11 -0700 Subject: [PATCH] take custom dataset into consideration (#12093) --- python/mxnet/gluon/data/dataloader.py | 3 ++- tests/python/unittest/test_gluon_data.py | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index 13ab544a03d3..e0b6aec294a0 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -160,7 +160,8 @@ def _as_in_context(data, ctx): def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" - dataset._fork() + if hasattr(dataset, '_fork') and callable(dataset._fork): + dataset._fork() while True: idx, samples = key_queue.get() if idx is None: diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 4dc4f3ac8819..53ce600629c8 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -116,6 +116,13 @@ def test_image_folder_dataset(): assert dataset.synsets == ['test_images'] assert len(dataset.items) == 16 +@with_seed() +def test_list_dataset(): + for num_worker in range(0, 3): + data = mx.gluon.data.DataLoader([([1,2], 0), ([3, 4], 1)], batch_size=1, num_workers=num_worker) + for d, l in data: + pass + class Dataset(gluon.data.Dataset): def __len__(self):