Skip to content

Commit

Permalink
fix unpicklable transform_first on windows (apache#13686)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhreshold authored and eric-haibin-lin committed Dec 19, 2018
1 parent e27e3e9 commit af7905f
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 11 deletions.
16 changes: 11 additions & 5 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,11 +88,7 @@ def transform_first(self, fn, lazy=True):
Dataset
The transformed dataset.
"""
def base_fn(x, *args):
if args:
return (fn(x),) + args
return fn(x)
return self.transform(base_fn, lazy)
return self.transform(_TransformFirstClosure(fn), lazy)


class SimpleDataset(Dataset):
Expand Down Expand Up @@ -129,6 +125,16 @@ def __getitem__(self, idx):
return self._fn(item)


class _TransformFirstClosure(object):
    """Callable wrapper that applies ``fn`` to the first argument only.

    Implemented as a class rather than a nested function so that the
    resulting callable can be pickled.

    Parameters
    ----------
    fn : callable
        Function applied to the first positional argument of each call.
    """
    def __init__(self, fn):
        self._fn = fn

    def __call__(self, x, *args):
        """Return ``fn(x)``, followed by any extra arguments unchanged.

        With extra positional arguments the result is the tuple
        ``(fn(x),) + args``; with none it is the bare ``fn(x)``.
        """
        first = self._fn(x)
        if not args:
            return first
        return (first,) + args

class ArrayDataset(Dataset):
"""A dataset that combines multiple dataset-like objects, e.g.
Datasets, lists, arrays, etc.
Expand Down
12 changes: 6 additions & 6 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y):
"""Named transform function since lambda function cannot be pickled."""
return x, y

def _dataset_transform_first_fn(x):
    """Identity transform used as the ``transform_first`` callback.

    Defined as a named module-level function because a lambda cannot be
    pickled.
    """
    return x

@with_seed()
def test_recordimage_dataset_with_data_loader_multiworker():
recfile = prepare_record()
Expand All @@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker():
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i

# try limit recursion depth
import sys
old_limit = sys.getrecursionlimit()
sys.setrecursionlimit(500) # this should be smaller than any default value used in python
dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn)
# with transform_first
dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn)
loader = gluon.data.DataLoader(dataset, 1, num_workers=5)

for i, (x, y) in enumerate(loader):
assert x.shape[0] == 1 and x.shape[3] == 3
assert y.asscalar() == i
sys.setrecursionlimit(old_limit)

@with_seed()
def test_sampler():
Expand Down

0 comments on commit af7905f

Please sign in to comment.