diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index eb1eb419cd02..438224f45f1d 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -151,7 +151,8 @@ def default_mp_batchify_fn(data): def worker_loop(dataset, key_queue, data_queue, batchify_fn): """Worker loop for multiprocessing DataLoader.""" - dataset._fork() + if hasattr(dataset, '_fork') and callable(dataset._fork): + dataset._fork() while True: idx, samples = key_queue.get() if idx is None: diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 043804487b5e..0550a5481255 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -116,6 +116,13 @@ def test_image_folder_dataset(): assert dataset.synsets == ['test_images'] assert len(dataset.items) == 16 +@with_seed() +def test_list_dataset(): + for num_worker in range(0, 3): + data = mx.gluon.data.DataLoader([([1,2], 0), ([3, 4], 1)], batch_size=1, num_workers=num_worker) + for d, l in data: + pass + class Dataset(gluon.data.Dataset): def __len__(self):