Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Dataset] add shard API #16175

Merged
merged 4 commits into from
Sep 18, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,36 @@ def filter(self, fn):
from . import FilterSampler
return _SampledDataset(self, FilterSampler(fn, self))

def shard(self, num_shards, index):
"""Returns a new dataset includes only 1/num_shards of this dataset.

For distributed training, be sure to shard before you randomize the dataset
(such as shuffle), if you want each worker to reach a unique subset.

Parameters
----------
num_shards : int
A integer representing the number of data shards.
index : int
A integer representing the index of the current shard.

Returns
-------
Dataset
The result dataset.
"""
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
assert index < num_shards, 'Shard index of out bound: %d out of %d'%(index, num_shards)
assert num_shards > 0, 'Number of shards must be greater than 0'
assert index >= 0, 'Index must be non-negative'
length = len(self)
shard_len = (length + num_shards - 1) // num_shards
# Compute the start index for this partition
start = shard_len * index
# Compute the end index for this partition
end = start + shard_len if index < num_shards - 1 else length
eric-haibin-lin marked this conversation as resolved.
Show resolved Hide resolved
from . import SequentialSampler
return _SampledDataset(self, SequentialSampler(end - start, start))

def take(self, count):
"""Returns a new dataset with at most `count` number of samples in it.

Expand Down
10 changes: 6 additions & 4 deletions python/mxnet/gluon/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,25 @@ def __len__(self):


class SequentialSampler(Sampler):
"""Samples elements from [0, length) sequentially.
"""Samples elements from [start, start+length) sequentially.

Parameters
----------
length : int
Length of the sequence.
start : int, default is 0
The start of the sequence index.
"""
def __init__(self, length):
def __init__(self, length, start=0):
self._length = length
self._start = start

def __iter__(self):
return iter(range(self._length))
return iter(range(self._start, self._start + self._length))

def __len__(self):
return self._length


class RandomSampler(Sampler):
"""Samples elements from [0, length) randomly without replacement.

Expand Down
13 changes: 13 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,19 @@ def test_dataset_filter():
for idx, sample in enumerate(a_xform_filtered):
assert sample % 10 == 0

def test_dataset_shard():
length = 11
a = mx.gluon.data.SimpleDataset([i for i in range(length)])
a_shard_0 = a.shard(2, 0)
a_shard_1 = a.shard(2, 1)
assert len(a_shard_0) + len(a_shard_1) == length
total = 0
for idx, sample in enumerate(a_shard_0):
total += sample
for idx, sample in enumerate(a_shard_1):
total += sample
assert total == sum(a)

def test_dataset_take():
length = 100
a = mx.gluon.data.SimpleDataset([i for i in range(length)])
Expand Down