Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Fixed a bug in Gluon DataLoader. (#15195)
Browse files Browse the repository at this point in the history
* Fixed a bug in Gluon DataLoader.
    Issue: #15025
    Fix: Broadened the scope of worker pool to iterators. Passed a reference of dataloader to the multi worker iterator

* Fixed a bug in Gluon DataLoader.
    Issue: #15025
    Fix: Broadened the scope of worker pool to iterators. Passed a reference of dataloader to the multi worker iterator

* Fixed a bug in Gluon DataLoader.
    Issue: #15025
    Fix: Broadened the scope of worker pool to iterators. Passed a reference of dataloader to the multi worker iterator

* Fixed a bug in Gluon DataLoader.
    Issue: #15025
    Fix: Broadened the scope of worker pool to iterators. Passed a reference of dataloader to the multi worker iterator
  • Loading branch information
chandana1332 authored and wkcn committed Jun 12, 2019
1 parent 769b882 commit 2e20094
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def _thread_worker_fn(samples, batchify_fn, dataset):
class _MultiWorkerIter(object):
"""Internal multi-worker iterator for DataLoader."""
def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None):
pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, data_loader=None):
self._worker_pool = worker_pool
self._batchify_fn = batchify_fn
self._batch_sampler = batch_sampler
Expand All @@ -421,6 +421,7 @@ def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
self._dataset = dataset
self._data_loader = data_loader
# pre-fetch
for _ in range(prefetch):
self._push_next()
Expand Down Expand Up @@ -582,7 +583,8 @@ def same_process_iter():
pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
prefetch=self._prefetch,
dataset=self._dataset if self._thread_pool else None)
dataset=self._dataset if self._thread_pool else None,
data_loader=self)

def __len__(self):
    """Return the number of batches this iterator will yield.

    Delegates to the batch sampler, which defines how the dataset is
    partitioned into batches.
    """
    return len(self._batch_sampler)
Expand Down
25 changes: 25 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import mxnet.ndarray as nd
from mxnet import context
from mxnet.gluon.data.dataset import Dataset
from mxnet.gluon.data.dataset import ArrayDataset

@with_seed()
def test_array_dataset():
Expand Down Expand Up @@ -279,6 +280,30 @@ def test_dataloader_context():
for _, x in enumerate(loader3):
assert x.context == context.cpu_pinned(custom_dev_id)

def batchify(a):
    """Identity batchify function.

    Returns the sampled elements unchanged, bypassing the DataLoader's
    default collation. Used by tests that only care about iteration
    mechanics, not batch layout.
    """
    return a

def test_dataloader_scope():
    """
    Bug: Gluon DataLoader terminates the process pool early while
    _MultiWorkerIter is operating on the pool.
    Tests that DataLoader is not garbage collected while the iterator is
    in use.
    """
    dataset = nd.ones(5)
    # NOTE: the DataLoader is built inline on purpose — no local name may
    # hold it, so only the iterator's internal reference keeps it (and its
    # worker pool) alive. Binding it to a variable would mask the bug.
    worker_iter = iter(
        DataLoader(dataset, batchify_fn=batchify, num_workers=1, batch_size=2)
    )
    assert next(worker_iter) is not None


if __name__ == '__main__':
import nose
Expand Down

0 comments on commit 2e20094

Please sign in to comment.