From 3aa843c8258793782d6cc3b81036556e5e9125f2 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 Aug 2018 03:47:13 +0000 Subject: [PATCH 1/2] add worker_fn argument to multiworker function --- python/mxnet/gluon/data/dataloader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index e0b6aec294a0..e7b5d091be70 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -183,7 +183,7 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False): class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader.""" - def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False): + def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, worker_fn=worker_loop): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) self._num_workers = num_workers self._dataset = dataset @@ -200,7 +200,7 @@ def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory= workers = [] for _ in range(self._num_workers): worker = multiprocessing.Process( - target=worker_loop, + target=worker_fn, args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn)) worker.daemon = True worker.start() From 9062d4097459c1575fedb85db979cd4d685bd991 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 15 Aug 2018 03:57:03 +0000 Subject: [PATCH 2/2] fix pylin --- python/mxnet/gluon/data/dataloader.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataloader.py b/python/mxnet/gluon/data/dataloader.py index e7b5d091be70..412d3134476b 100644 --- a/python/mxnet/gluon/data/dataloader.py +++ b/python/mxnet/gluon/data/dataloader.py @@ -183,7 +183,8 @@ def fetcher_loop(data_queue, data_buffer, pin_memory=False): class _MultiWorkerIter(object): """Interal multi-worker iterator for DataLoader.""" - def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, worker_fn=worker_loop): + def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, pin_memory=False, + worker_fn=worker_loop): assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) self._num_workers = num_workers self._dataset = dataset