diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index 28d19c9fe37c..a5399f259435 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -40,6 +40,70 @@ def __getitem__(self, idx):
     def __len__(self):
         raise NotImplementedError
 
+    def filter(self, fn):
+        """Returns a new dataset with samples filtered by the
+        filter function `fn`.
+
+        Note that if the Dataset is the result of a lazily transformed one with
+        transform(lazy=False), the filter is eagerly applied to the transformed
+        samples without materializing the transformed result. That is, the
+        transformation will be applied again whenever a sample is retrieved after
+        filter().
+
+        Parameters
+        ----------
+        fn : callable
+            A filter function that takes a sample as input and
+            returns a boolean. Samples that return False are discarded.
+
+        Returns
+        -------
+        Dataset
+            The filtered dataset.
+        """
+        from . import FilterSampler
+        return _SampledDataset(self, FilterSampler(fn, self))
+
+    def take(self, count):
+        """Returns a new dataset with at most `count` number of samples in it.
+
+        Parameters
+        ----------
+        count : int or None
+            A integer representing the number of elements of this dataset that
+            should be taken to form the new dataset. If count is None, or if count
+            is greater than the size of this dataset, the new dataset will contain
+            all elements of this dataset.
+
+        Returns
+        -------
+        Dataset
+            The result dataset.
+        """
+        if count is None or count > len(self):
+            count = len(self)
+        from . import SequentialSampler
+        return _SampledDataset(self, SequentialSampler(count))
+
+    def sample(self, sampler):
+        """Returns a new dataset with elements sampled by the sampler.
+
+        Parameters
+        ----------
+        sampler : Sampler
+            A Sampler that returns the indices of sampled elements.
+
+        Returns
+        -------
+        Dataset
+            The result dataset.
+        """
+        from . import Sampler
+        if not isinstance(sampler, Sampler):
+            raise TypeError('Invalid sampler type: %s. Expected gluon.data.Sampler instead.'%
+                            type(sampler))
+        return _SampledDataset(self, sampler)
+
     def transform(self, fn, lazy=True):
         """Returns a new dataset with each sample transformed by the
         transformer function `fn`.
@@ -135,6 +199,31 @@ def __call__(self, x, *args):
             return (self._fn(x),) + args
         return self._fn(x)
 
+class _FilteredDataset(Dataset):
+    """Dataset with a filter applied"""
+    def __init__(self, dataset, fn):
+        self._dataset = dataset
+        self._indices = [i for i, sample in enumerate(dataset) if fn(sample)]
+
+    def __len__(self):
+        return len(self._indices)
+
+    def __getitem__(self, idx):
+        return self._dataset[self._indices[idx]]
+
+class _SampledDataset(Dataset):
+    """Dataset with elements chosen by a sampler"""
+    def __init__(self, dataset, sampler):
+        self._dataset = dataset
+        self._sampler = sampler
+        self._indices = list(iter(sampler))
+
+    def __len__(self):
+        return len(self._sampler)
+
+    def __getitem__(self, idx):
+        return self._dataset[self._indices[idx]]
+
 class ArrayDataset(Dataset):
     """A dataset that combines multiple dataset-like objects, e.g.
     Datasets, lists, arrays, etc.
diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py
index 2f827c83bc4d..980c63134fac 100644
--- a/python/mxnet/gluon/data/sampler.py
+++ b/python/mxnet/gluon/data/sampler.py
@@ -18,7 +18,7 @@
 # coding: utf-8
 # pylint: disable=
 """Dataset sampler."""
-__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler']
+__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'FilterSampler', 'BatchSampler']
 
 import numpy as np
 
@@ -72,6 +72,27 @@ def __iter__(self):
     def __len__(self):
         return self._length
 
+class FilterSampler(Sampler):
+    """Samples elements from a Dataset for which `fn` returns True.
+
+    Parameters
+    ----------
+    fn : callable
+        A callable function that takes a sample and returns a boolean
+    dataset : Dataset
+        The dataset to filter.
+    """
+    def __init__(self, fn, dataset):
+        self._fn = fn
+        self._dataset = dataset
+        self._indices = [i for i, sample in enumerate(dataset) if fn(sample)]
+
+    def __iter__(self):
+        return iter(self._indices)
+
+    def __len__(self):
+        return len(self._indices)
+
 class BatchSampler(Sampler):
     """Wraps over another `Sampler` and return mini-batches of samples.
 
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 58e241b528ea..62bf7c3db869 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -283,6 +283,45 @@ def test_dataloader_context():
 def batchify(a):
     return a
 
+def test_dataset_filter():
+    length = 100
+    a = mx.gluon.data.SimpleDataset([i for i in range(length)])
+    a_filtered = a.filter(lambda x: x % 10 == 0)
+    assert(len(a_filtered) == 10)
+    for idx, sample in enumerate(a_filtered):
+        assert sample % 10 == 0
+    a_xform_filtered = a.transform(lambda x: x + 1).filter(lambda x: x % 10 == 0)
+    assert(len(a_xform_filtered) == 10)
+    # the filtered data is already transformed
+    for idx, sample in enumerate(a_xform_filtered):
+        assert sample % 10 == 0
+
+def test_dataset_take():
+    length = 100
+    a = mx.gluon.data.SimpleDataset([i for i in range(length)])
+    a_take_full = a.take(1000)
+    assert len(a_take_full) == length
+    a_take_full = a.take(None)
+    assert len(a_take_full) == length
+    count = 10
+    a_take_10 = a.take(count)
+    assert len(a_take_10) == count
+    expected_total = sum([i for i in range(count)])
+    total = 0
+    for idx, sample in enumerate(a_take_10):
+        assert sample < count
+        total += sample
+    assert total == expected_total
+
+    a_xform_take_10 = a.transform(lambda x: x * 10).take(count)
+    assert len(a_xform_take_10) == count
+    expected_total = sum([i * 10 for i in range(count)])
+    total = 0
+    for idx, sample in enumerate(a_xform_take_10):
+        assert sample < count * 10
+        total += sample
+    assert total == expected_total
+
 def test_dataloader_scope():
     """ Bug: Gluon DataLoader terminates the process pool early while