Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Dataset] Add take, filter API to dataset #16078

Merged
merged 7 commits into from
Sep 8, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,41 @@ def __getitem__(self, idx):
def __len__(self):
raise NotImplementedError

def filter(self, fn):
    """Returns a new dataset with samples filtered by the
    filter function `fn`.

    Note that if `fn` is applied to a lazily-transformed dataset
    (i.e. one created with ``transform(fn, lazy=True)``), constructing
    the filtered dataset eagerly evaluates the transform for every
    sample once in order to apply the filter, and the transform is
    applied again on each subsequent access via ``__getitem__``.

    Parameters
    ----------
    fn : callable
        A filter function that takes a sample as input and
        returns a boolean. Samples that return False are discarded.

    Returns
    -------
    Dataset
        The filtered dataset.
    """
    return _FilteredDataset(self, fn)

def take(self, count):
    """Returns a new dataset with at most `count` number of samples in it.

    The samples taken are the first `count` samples of this dataset,
    in their original order.

    Parameters
    ----------
    count : int or None
        An integer representing the number of elements of this dataset that
        should be taken to form the new dataset. If count is None, or if count
        is greater than the size of this dataset, the new dataset will contain
        all elements of this dataset.

    Returns
    -------
    Dataset
        The result dataset.
    """
    return _TakenDataset(self, count)

def transform(self, fn, lazy=True):
"""Returns a new dataset with each sample transformed by the
transformer function `fn`.
Expand Down Expand Up @@ -135,6 +170,38 @@ def __call__(self, x, *args):
return (self._fn(x),) + args
return self._fn(x)

class _FilteredDataset(Dataset):
    """Dataset with a filter applied.

    Constructing this dataset eagerly iterates over the entire source
    dataset once to record the indices of samples accepted by `fn`;
    `__getitem__` then indexes straight into the underlying dataset.
    """
    def __init__(self, dataset, fn):
        self._dataset = dataset
        # Precompute accepted indices in one pass (comprehension instead
        # of a manual append loop).
        self._indices = [i for i, sample in enumerate(dataset) if fn(sample)]

    def __len__(self):
        return len(self._indices)

    def __getitem__(self, idx):
        return self._dataset[self._indices[idx]]

class _TakenDataset(Dataset):
    """Dataset containing at most the first `count` samples of `dataset`."""
    def __init__(self, dataset, count):
        self._dataset = dataset
        # None, or a count larger than the dataset, means "take everything".
        if count is None or count > len(dataset):
            count = len(dataset)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if count < 0:
            raise ValueError('count must be non-negative, got %d'%count)
        self._count = count

    def __len__(self):
        return self._count

    def __getitem__(self, idx):
        # Normalize negative indices against the taken length so that e.g.
        # idx=-1 returns the last *taken* sample rather than silently
        # indexing the last sample of the underlying dataset.
        if idx < 0:
            idx += self._count
        if not 0 <= idx < self._count:
            raise IndexError('Invalid index %d. Expected index < %d'%(idx, self._count))
        return self._dataset[idx]

class ArrayDataset(Dataset):
"""A dataset that combines multiple dataset-like objects, e.g.
Datasets, lists, arrays, etc.
Expand Down
39 changes: 39 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,45 @@ def test_dataloader_context():
def batchify(a):
return a

def test_dataset_filter():
    """Dataset.filter keeps only matching samples, including when the
    source dataset has been transformed first."""
    length = 100
    a = mx.gluon.data.SimpleDataset(list(range(length)))
    a_filtered = a.filter(lambda x: x % 10 == 0)
    assert len(a_filtered) == 10
    for sample in a_filtered:
        assert sample % 10 == 0
    # filter() on a transformed dataset applies the predicate to the
    # post-transform values.
    a_xform_filtered = a.transform(lambda x: x + 1).filter(lambda x: x % 10 == 0)
    assert len(a_xform_filtered) == 10
    # the filtered data is already transformed
    for sample in a_xform_filtered:
        assert sample % 10 == 0

def test_dataset_take():
    """Dataset.take caps the length and preserves sample order/values."""
    length = 100
    a = mx.gluon.data.SimpleDataset(list(range(length)))
    # Taking more than the dataset holds (or None) yields the full dataset.
    a_take_full = a.take(1000)
    assert len(a_take_full) == length
    a_take_full = a.take(None)
    assert len(a_take_full) == length
    count = 10
    a_take_10 = a.take(count)
    assert len(a_take_10) == count
    expected_total = sum(range(count))
    total = 0
    for sample in a_take_10:
        assert sample < count
        total += sample
    assert total == expected_total

    # take() after a transform yields the transformed samples.
    a_xform_take_10 = a.transform(lambda x: x * 10).take(count)
    assert len(a_xform_take_10) == count
    expected_total = sum(i * 10 for i in range(count))
    total = 0
    for sample in a_xform_take_10:
        assert sample < count * 10
        total += sample
    assert total == expected_total

def test_dataloader_scope():
"""
Bug: Gluon DataLoader terminates the process pool early while
Expand Down