diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py
index a5399f259435..7f2c6342d595 100644
--- a/python/mxnet/gluon/data/dataset.py
+++ b/python/mxnet/gluon/data/dataset.py
@@ -64,6 +64,43 @@ def filter(self, fn):
         from . import FilterSampler
         return _SampledDataset(self, FilterSampler(fn, self))
 
+    def shard(self, num_shards, index):
+        """Returns a new dataset that includes only 1/num_shards of this dataset.
+
+        The dataset is split into `num_shards` contiguous, non-overlapping
+        shards. When the length is not evenly divisible, the first
+        `len(self) % num_shards` shards each get one extra sample.
+
+        For distributed training, be sure to shard before you randomize the dataset
+        (such as shuffle), if you want each worker to reach a unique subset.
+
+        Parameters
+        ----------
+        num_shards : int
+            An integer representing the number of data shards.
+        index : int
+            An integer representing the index of the current shard.
+
+        Returns
+        -------
+        Dataset
+            The result dataset.
+        """
+        # Validate num_shards first so that a non-positive shard count is not
+        # misreported as an out-of-bound index.
+        assert num_shards > 0, 'Number of shards must be greater than 0'
+        assert index >= 0, 'Index must be non-negative'
+        assert index < num_shards, 'Shard index out of bound: %d out of %d'%(index, num_shards)
+        length = len(self)
+        shard_len = length // num_shards
+        rest = length % num_shards
+        # Compute the start index for this partition
+        start = shard_len * index + min(index, rest)
+        # Compute the end index for this partition
+        end = start + shard_len + (index < rest)
+        from . import SequentialSampler
+        return _SampledDataset(self, SequentialSampler(end - start, start))
+
     def take(self, count):
         """Returns a new dataset with at most `count` number of samples in it.
 
diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py
index 980c63134fac..f53e636317df 100644
--- a/python/mxnet/gluon/data/sampler.py
+++ b/python/mxnet/gluon/data/sampler.py
@@ -36,23 +36,26 @@ def __len__(self):
 
 
 class SequentialSampler(Sampler):
-    """Samples elements from [0, length) sequentially.
+    """Samples elements from [start, start+length) sequentially.
 
     Parameters
     ----------
     length : int
         Length of the sequence.
+    start : int, default is 0
+        The start of the sequence index.
     """
-    def __init__(self, length):
+    def __init__(self, length, start=0):
         self._length = length
+        self._start = start
 
     def __iter__(self):
-        return iter(range(self._length))
+        return iter(range(self._start, self._start + self._length))
 
     def __len__(self):
         return self._length
 
 
 class RandomSampler(Sampler):
     """Samples elements from [0, length) randomly without replacement.
 
diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py
index 62bf7c3db869..80bea2a511f2 100644
--- a/tests/python/unittest/test_gluon_data.py
+++ b/tests/python/unittest/test_gluon_data.py
@@ -296,6 +296,26 @@ def test_dataset_filter():
     for idx, sample in enumerate(a_xform_filtered):
         assert sample % 10 == 0
 
+def test_dataset_shard():
+    length = 9
+    a = mx.gluon.data.SimpleDataset([i for i in range(length)])
+    shard_0 = a.shard(4, 0)
+    shard_1 = a.shard(4, 1)
+    shard_2 = a.shard(4, 2)
+    shard_3 = a.shard(4, 3)
+    assert len(shard_0) + len(shard_1) + len(shard_2) + len(shard_3) == length
+    # 9 samples over 4 shards: only the first 9 % 4 == 1 shard gets an extra sample.
+    assert len(shard_0) == 3
+    assert len(shard_1) == 2
+    assert len(shard_2) == 2
+    assert len(shard_3) == 2
+    # Concatenating the shards in order must reproduce the dataset exactly,
+    # i.e. the shards are contiguous, disjoint and cover every sample.
+    samples = []
+    for shard in [shard_0, shard_1, shard_2, shard_3]:
+        samples.extend(shard)
+    assert samples == list(a)
+
 def test_dataset_take():
     length = 100
     a = mx.gluon.data.SimpleDataset([i for i in range(length)])