[Dataset] Add take, filter, sample API to dataset (apache#16078)
* first commit

* fix lint

* add more test  for transformed data

* address CR

* Update dataset.py

* use sampler to implement

* Update dataset.py
eric-haibin-lin authored and larroy committed Sep 28, 2019
1 parent 0220d4f commit 74d69cd
Showing 3 changed files with 150 additions and 1 deletion.
89 changes: 89 additions & 0 deletions python/mxnet/gluon/data/dataset.py
@@ -40,6 +40,70 @@ def __getitem__(self, idx):
    def __len__(self):
        raise NotImplementedError

    def filter(self, fn):
        """Returns a new dataset with samples filtered by the
        filter function `fn`.

        Note that if the Dataset is the result of a lazy transform (i.e.
        transform() with the default lazy=True), the filter is eagerly applied
        to the transformed samples without materializing the transformed
        result. That is, the transformation will be applied again whenever a
        sample is retrieved after filter().

        Parameters
        ----------
        fn : callable
            A filter function that takes a sample as input and returns a
            boolean. Samples for which `fn` returns False are discarded.

        Returns
        -------
        Dataset
            The filtered dataset.
        """
        from . import FilterSampler
        return _SampledDataset(self, FilterSampler(fn, self))
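
A minimal usage sketch of filter() (illustrative only; it mirrors the new unit test further down and assumes a SimpleDataset of plain integers):

import mxnet as mx

# Keep only even integers; the surviving indices are computed eagerly when
# filter() is called, but samples are still read from the original dataset.
dataset = mx.gluon.data.SimpleDataset(list(range(10)))
evens = dataset.filter(lambda x: x % 2 == 0)
assert len(evens) == 5
assert [evens[i] for i in range(len(evens))] == [0, 2, 4, 6, 8]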

    def take(self, count):
        """Returns a new dataset with at most `count` samples in it.

        Parameters
        ----------
        count : int or None
            An integer representing the number of elements of this dataset that
            should be taken to form the new dataset. If count is None, or if count
            is greater than the size of this dataset, the new dataset will contain
            all elements of this dataset.

        Returns
        -------
        Dataset
            The result dataset.
        """
        if count is None or count > len(self):
            count = len(self)
        from . import SequentialSampler
        return _SampledDataset(self, SequentialSampler(count))
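
A short sketch of the take() semantics described above (hypothetical values; a count of None, or one larger than the dataset, is clamped to the full length):

import mxnet as mx

dataset = mx.gluon.data.SimpleDataset(list(range(100)))
first_ten = dataset.take(10)       # first 10 samples, in order
everything = dataset.take(1000)    # count > len(dataset) keeps all samples
also_everything = dataset.take(None)
assert len(first_ten) == 10
assert len(everything) == len(also_everything) == 100
assert first_ten[9] == 9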

    def sample(self, sampler):
        """Returns a new dataset with elements sampled by the sampler.

        Parameters
        ----------
        sampler : Sampler
            A Sampler that returns the indices of sampled elements.

        Returns
        -------
        Dataset
            The result dataset.
        """
        from . import Sampler
        if not isinstance(sampler, Sampler):
            raise TypeError('Invalid sampler type: %s. Expected gluon.data.Sampler instead.' %
                            type(sampler))
        return _SampledDataset(self, sampler)
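
A sketch of sample() with the existing RandomSampler (any gluon.data.Sampler is accepted; the assertions are illustrative):

import mxnet as mx

dataset = mx.gluon.data.SimpleDataset(list(range(100)))
# RandomSampler yields every index exactly once, in random order.
shuffled = dataset.sample(mx.gluon.data.RandomSampler(len(dataset)))
assert len(shuffled) == 100
assert sorted(shuffled[i] for i in range(len(shuffled))) == list(range(100))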

    def transform(self, fn, lazy=True):
        """Returns a new dataset with each sample transformed by the
        transformer function `fn`.
@@ -135,6 +199,31 @@ def __call__(self, x, *args):
            return (self._fn(x),) + args
        return self._fn(x)

class _FilteredDataset(Dataset):
    """Dataset with a filter applied"""
    def __init__(self, dataset, fn):
        self._dataset = dataset
        self._indices = [i for i, sample in enumerate(dataset) if fn(sample)]

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

class _SampledDataset(Dataset):
    """Dataset with elements chosen by a sampler"""
    def __init__(self, dataset, sampler):
        self._dataset = dataset
        self._sampler = sampler
        self._indices = list(iter(sampler))

    def __len__(self):
        return len(self._sampler)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]
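
Because _SampledDataset is itself a Dataset, the results of take(), filter(), and sample() can be chained; a small illustrative sketch:

import mxnet as mx

dataset = mx.gluon.data.SimpleDataset(list(range(100)))
# take() returns a _SampledDataset, which again exposes filter()/take()/sample().
chained = dataset.take(50).filter(lambda x: x % 10 == 0)
assert [chained[i] for i in range(len(chained))] == [0, 10, 20, 30, 40]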

class ArrayDataset(Dataset):
    """A dataset that combines multiple dataset-like objects, e.g.
    Datasets, lists, arrays, etc.
23 changes: 22 additions & 1 deletion python/mxnet/gluon/data/sampler.py
@@ -18,7 +18,7 @@
# coding: utf-8
# pylint: disable=
"""Dataset sampler."""
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'BatchSampler']
__all__ = ['Sampler', 'SequentialSampler', 'RandomSampler', 'FilterSampler', 'BatchSampler']

import numpy as np

@@ -72,6 +72,27 @@ def __iter__(self):
    def __len__(self):
        return self._length

class FilterSampler(Sampler):
    """Samples elements from a Dataset for which `fn` returns True.

    Parameters
    ----------
    fn : callable
        A function that takes a sample and returns a boolean.
    dataset : Dataset
        The dataset to filter.
    """
    def __init__(self, fn, dataset):
        self._fn = fn
        self._dataset = dataset
        self._indices = [i for i, sample in enumerate(dataset) if fn(sample)]

    def __iter__(self):
        return iter(self._indices)

    def __len__(self):
        return len(self._indices)
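
A FilterSampler can also be used directly, for example passed to a DataLoader through its sampler argument; a sketch assuming a small integer dataset:

import mxnet as mx

dataset = mx.gluon.data.SimpleDataset(list(range(20)))
keep_small = mx.gluon.data.FilterSampler(lambda x: x < 5, dataset)
loader = mx.gluon.data.DataLoader(dataset, sampler=keep_small, batch_size=5)
for batch in loader:
    print(batch)  # a single batch holding the samples 0..4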


class BatchSampler(Sampler):
    """Wraps over another `Sampler` and returns mini-batches of samples.
39 changes: 39 additions & 0 deletions tests/python/unittest/test_gluon_data.py
@@ -283,6 +283,45 @@ def test_dataloader_context():
def batchify(a):
    return a

def test_dataset_filter():
    length = 100
    a = mx.gluon.data.SimpleDataset([i for i in range(length)])
    a_filtered = a.filter(lambda x: x % 10 == 0)
    assert(len(a_filtered) == 10)
    for idx, sample in enumerate(a_filtered):
        assert sample % 10 == 0
    a_xform_filtered = a.transform(lambda x: x + 1).filter(lambda x: x % 10 == 0)
    assert(len(a_xform_filtered) == 10)
    # the filtered data is already transformed
    for idx, sample in enumerate(a_xform_filtered):
        assert sample % 10 == 0

def test_dataset_take():
    length = 100
    a = mx.gluon.data.SimpleDataset([i for i in range(length)])
    a_take_full = a.take(1000)
    assert len(a_take_full) == length
    a_take_full = a.take(None)
    assert len(a_take_full) == length
    count = 10
    a_take_10 = a.take(count)
    assert len(a_take_10) == count
    expected_total = sum([i for i in range(count)])
    total = 0
    for idx, sample in enumerate(a_take_10):
        assert sample < count
        total += sample
    assert total == expected_total

    a_xform_take_10 = a.transform(lambda x: x * 10).take(count)
    assert len(a_xform_take_10) == count
    expected_total = sum([i * 10 for i in range(count)])
    total = 0
    for idx, sample in enumerate(a_xform_take_10):
        assert sample < count * 10
        total += sample
    assert total == expected_total

def test_dataloader_scope():
    """
    Bug: Gluon DataLoader terminates the process pool early while
