Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[WIP]Attempt to root cause test failure in v1.x branch #19879

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 1 addition & 69 deletions python/mxnet/gluon/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,55 +572,11 @@ def default_batchify_fn(data):
unless you are experiencing timeout and you know it's due to slow data loading.
Sometimes a full `shared_memory` will cause all workers to hang and cause a timeout. In these
cases, please reduce `num_workers` or increase the system `shared_memory` size instead.
auto_reload : bool, default is True
Controls whether data is prefetched automatically, both at construction and each time an iterator is exhausted.

Example:
>>> from mxnet.gluon.data import DataLoader, ArrayDataset
>>> train_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)])
>>> def transform_train(sample):
... if sample == 0 : print('(pre)fetching data here')
... return sample
...
>>> train_iter = DataLoader(train_data.transform_first(transform_train),
... auto_reload=False, batch_size=1,num_workers=1)
>>> # no prefetch is performed, the prefetch & autoload start after
>>> # train_iter.__iter__() is called.
>>> for i in train_iter:pass
(pre)fetching data here
>>> train_iter = DataLoader(train_data.transform_first(transform_train),
... batch_size=1,num_workers=1)
(pre)fetching data here
>>> it = iter(train_iter) # nothing is generated since lazy-evaluation occurs
>>> it2 = iter(train_iter)
>>> it3 = iter(train_iter)
>>> it4 = iter(train_iter)
>>> _ = next(it2) # the first iter we are using is the prefetched iter.
>>> _ = next(it) # since the prefetched iter is consumed, we have to fetch data for `it`.
(pre)fetching data here
>>> _ = [None for _ in it3]
(pre)fetching data here
(pre)fetching data here
>>> # Here, 2 prefetches are triggered: one fetches the first batch of `it3`, and
>>> # the other is performed automatically when `it3` yields its last item.
>>> _ = [None for _ in it]
>>> # no prefetch happens here since train_iter has already prefetched data.
>>> _ = next(it4)
>>> # since the prefetch has been performed, it4 becomes the prefetched iter.
>>>
>>> test_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)])
>>> test_iter = DataLoader(test_data, batch_size=1,num_workers=1)
>>> for epoch in range(200):
... # there is almost no difference between it and the default DataLoader
... for data, label in train_iter:
... # training...
... for data, label in test_iter:
... # testing...
"""
def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
last_batch=None, batch_sampler=None, batchify_fn=None,
num_workers=0, pin_memory=False, pin_device_id=0,
prefetch=None, thread_pool=False, timeout=120, auto_reload=False):
prefetch=None, thread_pool=False, timeout=120):
self._dataset = dataset
self._pin_memory = pin_memory
self._pin_device_id = pin_device_id
Expand Down Expand Up @@ -671,24 +627,8 @@ def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
self._batchify_fn = default_batchify_fn
else:
self._batchify_fn = batchify_fn
self.auto_reload = auto_reload
if self.auto_reload:
self.refresh()
else:
self.clean() # ensure self._iter exists.

def __iter__(self):
if self._iter is None:
self.refresh()
t = self._iter
self._iter = None # ensure a single iter would not using twice.
for item in t:
yield item
if self._iter is None and self.auto_reload:
# ensure we do not waste any existing iter by mistake
self.refresh()

def _prefetch_iter(self):
if self._num_workers == 0:
def same_process_iter():
for batch in self._batch_sampler:
Expand All @@ -715,11 +655,3 @@ def __del__(self):
# https://bugs.python.org/issue34172
assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
self._worker_pool.terminate()

def refresh(self):
"""Refresh its iter, fetch data again from its dataset"""
self._iter = self._prefetch_iter()

def clean(self):
"""Remove the prefetched iter; prefetching will start when __iter__() is next called."""
self._iter = None
11 changes: 4 additions & 7 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,10 @@ def __getitem__(self, key):
def test_multi_worker():
data = Dataset()
for thread_pool in [True, False]:
for auto_reload in [True, False]:
loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5,
thread_pool=thread_pool,auto_reload=auto_reload)
for i, batch in enumerate(loader):
assert (batch.asnumpy() == i).all()
loader = gluon.data.DataLoader(data, batch_size=1, num_workers=5, thread_pool=thread_pool)
for i, batch in enumerate(loader):
assert (batch.asnumpy() == i).all()


@with_seed()
def test_multi_worker_shape():
Expand Down Expand Up @@ -251,7 +250,6 @@ def _batchify(data):
nd.array(y_lens, ctx=context.Context('cpu_shared', 0)))

@with_seed()
@unittest.skip("skipping flaky test - see https://github.com/apache/incubator-mxnet/issues/19877")
def test_multi_worker_forked_data_loader():
data = _Dummy(False)
loader = DataLoader(data, batch_size=40, batchify_fn=_batchify, num_workers=2)
Expand All @@ -266,7 +264,6 @@ def test_multi_worker_forked_data_loader():
pass

@with_seed()
@unittest.skip("skipping flaky test - see https://github.com/apache/incubator-mxnet/issues/19877")
def test_multi_worker_dataloader_release_pool():
# will trigger a "too many open files" error if the pool is not released properly
if os.name == 'nt':
Expand Down