diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index c93a4b1cd6b9..28d19c9fe37c 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -88,11 +88,7 @@ def transform_first(self, fn, lazy=True): Dataset The transformed dataset. """ - def base_fn(x, *args): - if args: - return (fn(x),) + args - return fn(x) - return self.transform(base_fn, lazy) + return self.transform(_TransformFirstClosure(fn), lazy) class SimpleDataset(Dataset): @@ -129,6 +125,16 @@ def __getitem__(self, idx): return self._fn(item) +class _TransformFirstClosure(object): + """Use callable object instead of nested function, it can be pickled.""" + def __init__(self, fn): + self._fn = fn + + def __call__(self, x, *args): + if args: + return (self._fn(x),) + args + return self._fn(x) + class ArrayDataset(Dataset): """A dataset that combines multiple dataset-like objects, e.g. Datasets, lists, arrays, etc. diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 6a5322616e20..353a819ddbf6 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -77,6 +77,10 @@ def _dataset_transform_fn(x, y): """Named transform function since lambda function cannot be pickled.""" return x, y +def _dataset_transform_first_fn(x): + """Named transform function since lambda function cannot be pickled.""" + return x + @with_seed() def test_recordimage_dataset_with_data_loader_multiworker(): recfile = prepare_record() @@ -95,17 +99,13 @@ def test_recordimage_dataset_with_data_loader_multiworker(): assert x.shape[0] == 1 and x.shape[3] == 3 assert y.asscalar() == i - # try limit recursion depth - import sys - old_limit = sys.getrecursionlimit() - sys.setrecursionlimit(500) # this should be smaller than any default value used in python - dataset = gluon.data.vision.ImageRecordDataset(recfile).transform(_dataset_transform_fn) + # with transform_first + dataset = gluon.data.vision.ImageRecordDataset(recfile).transform_first(_dataset_transform_first_fn) loader = gluon.data.DataLoader(dataset, 1, num_workers=5) for i, (x, y) in enumerate(loader): assert x.shape[0] == 1 and x.shape[3] == 3 assert y.asscalar() == i - sys.setrecursionlimit(old_limit) @with_seed() def test_sampler():