From 8b6d53a0d65883b92d5b5c0bc98a64aa03a3e337 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 2 Sep 2019 23:34:07 +0000 Subject: [PATCH 1/7] first commit --- python/mxnet/gluon/data/dataset.py | 67 ++++++++++++++++++++++++ tests/python/unittest/test_gluon_data.py | 25 +++++++++ 2 files changed, 92 insertions(+) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 28d19c9fe37c..58ac500e9b34 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -40,6 +40,41 @@ def __getitem__(self, idx): def __len__(self): raise NotImplementedError + def filter(self, fn): + """Returns a new dataset with samples filtered by the + filter function `fn`. + + Parameters + ---------- + fn : callable + A filter function that takes a sample as input and + returns a boolean. Samples for which `fn` returns False are discarded. + + Returns + ------- + Dataset + The filtered dataset. + """ + return _FilteredDataset(self, fn) + + def take(self, count): + """Returns a new dataset with at most `count` number of samples in it. + + Parameters + ---------- + count : int or None + An integer representing the number of elements of this dataset that + should be taken to form the new dataset. If count is None, or if count + is greater than the size of this dataset, the new dataset will contain + all elements of this dataset. + + Returns + ------- + Dataset + The result dataset. + """ + return _TakenDataset(self, count) + def transform(self, fn, lazy=True): """Returns a new dataset with each sample transformed by the transformer function `fn`. 
@@ -135,6 +170,38 @@ def __call__(self, x, *args): return (self._fn(x),) + args return self._fn(x) +class _FilteredDataset(Dataset): + """Dataset with a filter applied""" + def __init__(self, dataset, fn): + self._dataset = dataset + self._indices = [] + for i in range(len(dataset)): + if fn(dataset[i]): + self._indices.append(i) + + def __len__(self): + return len(self._indices) + + def __getitem__(self, idx): + return self._dataset[self._indices[idx]] + +class _TakenDataset(Dataset): + """Dataset with at most `count` samples""" + def __init__(self, dataset, count): + self._dataset = dataset + self._count = count + if count is None or count > len(dataset): + self._count = len(dataset) + assert self._count >= 0 + + def __len__(self): + return self._count + + def __getitem__(self, idx): + if idx >= len(self): + raise IndexError('Invalid index %d. Expected index < %d'%(idx, len(self))) + return self._dataset[idx] + class ArrayDataset(Dataset): """A dataset that combines multiple dataset-like objects, e.g. Datasets, lists, arrays, etc. 
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 58e241b528ea..54b48856ad99 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -283,6 +283,31 @@ def test_dataloader_context(): def batchify(a): return a +def test_dataset_filter(): + length = 100 + a = mx.gluon.data.SimpleDataset([i for i in range(length)]) + a_filtered = a.filter(lambda x: x % 10 == 0) + assert(len(a_filtered) == 10) + for idx, sample in enumerate(gluon.data.DataLoader(a_filtered, 1)): + assert sample % 10 == 0 + +def test_dataset_take(): + length = 100 + a = mx.gluon.data.SimpleDataset([i for i in range(length)]) + a_take_full = a.take(1000) + assert len(a_take_full) == length + a_take_full = a.take(None) + assert len(a_take_full) == length + count = 10 + a_take_10 = a.take(count) + assert len(a_take_10) == count + expected_total = sum([i for i in range(count)]) + total = 0 + for idx, sample in enumerate(a_take_10):#gluon.data.DataLoader(a_take_10, 1)): + assert sample < count + total += sample + assert total == expected_total + def test_dataloader_scope(): """ Bug: Gluon DataLoader terminates the process pool early while From d0b7ed277798ce03c5e83903195916c657dac250 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 2 Sep 2019 23:57:27 +0000 Subject: [PATCH 2/7] fix lint --- python/mxnet/gluon/data/dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 58ac500e9b34..1606bb4cddb3 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -175,8 +175,8 @@ class _FilteredDataset(Dataset): def __init__(self, dataset, fn): self._dataset = dataset self._indices = [] - for i in range(len(dataset)): - if fn(dataset[i]): + for i, sample in enumerate(dataset): + if fn(sample): self._indices.append(i) def __len__(self): From 097f5fa1c3ba7bf7fc3de6e988ef288af87a2d3d Mon 
Sep 17 00:00:00 2001 From: Ubuntu Date: Tue, 3 Sep 2019 03:42:01 +0000 Subject: [PATCH 3/7] add more test for transformed data --- tests/python/unittest/test_gluon_data.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 54b48856ad99..62bf7c3db869 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -288,7 +288,12 @@ def test_dataset_filter(): a = mx.gluon.data.SimpleDataset([i for i in range(length)]) a_filtered = a.filter(lambda x: x % 10 == 0) assert(len(a_filtered) == 10) - for idx, sample in enumerate(gluon.data.DataLoader(a_filtered, 1)): + for idx, sample in enumerate(a_filtered): + assert sample % 10 == 0 + a_xform_filtered = a.transform(lambda x: x + 1).filter(lambda x: x % 10 == 0) + assert(len(a_xform_filtered) == 10) + # the filtered data is already transformed + for idx, sample in enumerate(a_xform_filtered): assert sample % 10 == 0 def test_dataset_take(): @@ -303,11 +308,20 @@ def test_dataset_take(): assert len(a_take_10) == count expected_total = sum([i for i in range(count)]) total = 0 - for idx, sample in enumerate(a_take_10):#gluon.data.DataLoader(a_take_10, 1)): + for idx, sample in enumerate(a_take_10): assert sample < count total += sample assert total == expected_total + a_xform_take_10 = a.transform(lambda x: x * 10).take(count) + assert len(a_xform_take_10) == count + expected_total = sum([i * 10 for i in range(count)]) + total = 0 + for idx, sample in enumerate(a_xform_take_10): + assert sample < count * 10 + total += sample + assert total == expected_total + def test_dataloader_scope(): """ Bug: Gluon DataLoader terminates the process pool early while From ffc786102c1e344bcbd809f3ad7cc870f2e90c76 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Tue, 3 Sep 2019 12:10:41 -0700 Subject: [PATCH 4/7] address CR --- python/mxnet/gluon/data/dataset.py | 11 +++++++---- 1 
file changed, 7 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 1606bb4cddb3..6e7c3ef08a37 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -44,6 +44,12 @@ def filter(self, fn): """Returns a new dataset with samples filtered by the filter function `fn`. + Note that if the Dataset is the result of a lazily transformed one with + transform(lazy=True), the filter is eagerly applied to the transformed + samples without materializing the transformed result. That is, the + transformation will be applied again whenever a sample is retrieved after + filter(). + Parameters ---------- fn : callable @@ -174,10 +180,7 @@ class _FilteredDataset(Dataset): """Dataset with a filter applied""" def __init__(self, dataset, fn): self._dataset = dataset - self._indices = [] - for i, sample in enumerate(dataset): - if fn(sample): - self._indices.append(i) + self._indices = [i for for i, sample in enumerate(dataset) if fn(sample)] def __len__(self): return len(self._indices) From ab094f74131a4d2e0caf5433cc8d82124c718405 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Tue, 3 Sep 2019 14:11:23 -0700 Subject: [PATCH 5/7] Update dataset.py --- python/mxnet/gluon/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 6e7c3ef08a37..9ae6619ec8a6 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -180,7 +180,7 @@ class _FilteredDataset(Dataset): """Dataset with a filter applied""" def __init__(self, dataset, fn): self._dataset = dataset - self._indices = [i for for i, sample in enumerate(dataset) if fn(sample)] + self._indices = [i for i, sample in enumerate(dataset) if fn(sample)] def __len__(self): return len(self._indices) From fcfb4acff71cc1f3887b38629595371c66c2c88d Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Fri, 6 Sep 2019 
20:47:50 -0700 Subject: [PATCH 6/7] use sampler to implement --- python/mxnet/gluon/data/dataset.py | 45 +++++++++++++++++++++--------- python/mxnet/gluon/data/sampler.py | 23 ++++++++++++++- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 9ae6619ec8a6..5b81a6251a89 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -61,7 +61,8 @@ def filter(self, fn): Dataset The filtered dataset. """ - return _FilteredDataset(self, fn) + from . import FilterSampler + return _SampledDataset(self, FilterSampler(fn, self)) def take(self, count): """Returns a new dataset with at most `count` number of samples in it. @@ -79,7 +80,29 @@ def take(self, count): Dataset The result dataset. """ - return _TakenDataset(self, count) + if count is None or count > len(self): + count = len(self) + from . import SequentialSampler + return _SampledDataset(self, SequentialSampler(count)) + + def sample(self, sampler): + """Returns a new dataset with elements sampled by the sampler. + + Parameters + ---------- + sampler : Sampler + A Sampler that returns the indices of sampled elements. + + Returns + ------- + Dataset + The result dataset. + """ + from . import Sampler + if not isinstance(sampler, Sampler): + raise TypeError('Invalid sampler type: %s. 
Expected gluon.data.Sampler instead.'% + type(sampler)) + return _SampledDataset(self, sampler) def transform(self, fn, lazy=True): """Returns a new dataset with each sample transformed by the @@ -188,22 +211,18 @@ def __len__(self): def __getitem__(self, idx): return self._dataset[self._indices[idx]] -class _TakenDataset(Dataset): - """Dataset with at most `count` samples""" - def __init__(self, dataset, count): +class _SampledDataset(Dataset): + """Dataset with elements chosen by a sampler""" + def __init__(self, dataset, sampler): self._dataset = dataset - self._count = count - if count is None or count > len(dataset): - self._count = len(dataset) - assert self._count >= 0 + self._sampler = sampler + self._indices = list(iter(sampler)) def __len__(self): - return self._count + return len(self._sampler) def __getitem__(self, idx): - if idx >= len(self): - raise IndexError('Invalid index %d. Expected index < %d'%(idx, len(self))) - return self._dataset[idx] + return self._dataset[self._indices[idx]] class ArrayDataset(Dataset): """A dataset that combines multiple dataset-like objects, e.g. diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py index 2f827c83bc4d..980c63134fac 100644 --- a/python/mxnet/gluon/data/sampler.py +++ b/python/mxnet/gluon/data/sampler.py @@ -18,7 +18,7 @@ # coding: utf-8 # pylint: disable= """Dataset sampler.""" -__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler'] +__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'FilterSampler', 'BatchSampler'] import numpy as np @@ -72,6 +72,27 @@ def __iter__(self): def __len__(self): return self._length +class FilterSampler(Sampler): + """Samples elements from a Dataset for which `fn` returns True. + + Parameters + ---------- + fn : callable + A callable function that takes a sample and returns a boolean + dataset : Dataset + The dataset to filter. 
+ """ + def __init__(self, fn, dataset): + self._fn = fn + self._dataset = dataset + self._indices = [i for i, sample in enumerate(dataset) if fn(sample)] + + def __iter__(self): + return iter(self._indices) + + def __len__(self): + return len(self._indices) + class BatchSampler(Sampler): """Wraps over another `Sampler` and return mini-batches of samples. From d75a131ec42327ad4dcf0a734be4392c93bc8255 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Sat, 7 Sep 2019 16:38:30 -0700 Subject: [PATCH 7/7] Update dataset.py --- python/mxnet/gluon/data/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 5b81a6251a89..a5399f259435 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -101,7 +101,7 @@ def sample(self, sampler): from . import Sampler if not isinstance(sampler, Sampler): raise TypeError('Invalid sampler type: %s. Expected gluon.data.Sampler instead.'% - type(sampler)) + type(sampler)) return _SampledDataset(self, sampler) def transform(self, fn, lazy=True):