Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Dataset] add shard API (#16175)
Browse files Browse the repository at this point in the history
* add shard API

* more balanced

* fix a bug

* Update test_gluon_data.py
  • Loading branch information
eric-haibin-lin committed Sep 18, 2019
1 parent 956cfa3 commit f75d093
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 4 deletions.
31 changes: 31 additions & 0 deletions python/mxnet/gluon/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,37 @@ def filter(self, fn):
from . import FilterSampler
return _SampledDataset(self, FilterSampler(fn, self))

def shard(self, num_shards, index):
    """Return a new dataset that includes only 1/num_shards of this dataset.

    For distributed training, be sure to shard before you randomize the dataset
    (such as shuffle), if you want each worker to reach a unique subset.

    The samples are split into `num_shards` contiguous index ranges. When the
    dataset length is not evenly divisible, the first
    ``len(self) % num_shards`` shards each receive one extra sample.

    Parameters
    ----------
    num_shards : int
        An integer representing the number of data shards.
    index : int
        An integer representing the index of the current shard.

    Returns
    -------
    Dataset
        The result dataset.

    Raises
    ------
    AssertionError
        If `num_shards` is not positive, or `index` is outside
        ``[0, num_shards)``.
    """
    # Check the shard count first so that a non-positive `num_shards` is
    # reported as such instead of as an out-of-bound index.
    assert num_shards > 0, 'Number of shards must be greater than 0'
    assert index >= 0, 'Index must be non-negative'
    assert index < num_shards, \
        'Shard index out of bound: %d out of %d' % (index, num_shards)
    length = len(self)
    shard_len, rest = divmod(length, num_shards)
    # Start index of this partition: every earlier shard contributes
    # `shard_len` samples, and the first `rest` shards one more each.
    start = shard_len * index + min(index, rest)
    # End index (exclusive): this shard gets an extra sample if it is one of
    # the first `rest` shards.
    end = start + shard_len + (index < rest)
    # Imported at call time, as the sibling `filter` method does for its sampler.
    from . import SequentialSampler
    return _SampledDataset(self, SequentialSampler(end - start, start))

def take(self, count):
"""Returns a new dataset with at most `count` number of samples in it.
Expand Down
10 changes: 6 additions & 4 deletions python/mxnet/gluon/data/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,23 +36,25 @@ def __len__(self):


class SequentialSampler(Sampler):
    """Samples elements from [start, start+length) sequentially.

    Parameters
    ----------
    length : int
        Length of the sequence.
    start : int, default is 0
        The start of the sequence index.
    """
    def __init__(self, length, start=0):
        self._length = length
        self._start = start

    def __iter__(self):
        # Yields start, start + 1, ..., start + length - 1.
        return iter(range(self._start, self._start + self._length))

    def __len__(self):
        # Number of indices produced; independent of `start`.
        return self._length


class RandomSampler(Sampler):
"""Samples elements from [0, length) randomly without replacement.
Expand Down
18 changes: 18 additions & 0 deletions tests/python/unittest/test_gluon_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,24 @@ def test_dataset_filter():
for idx, sample in enumerate(a_xform_filtered):
assert sample % 10 == 0

def test_dataset_shard():
    """Check that sharding 9 samples over 4 shards is balanced and lossless."""
    length = 9
    num_shards = 4
    a = mx.gluon.data.SimpleDataset(list(range(length)))
    shards = [a.shard(num_shards, i) for i in range(num_shards)]

    # The shard sizes must add up to the full dataset, with the single
    # leftover sample assigned to the first shard.
    sizes = [len(part) for part in shards]
    assert sum(sizes) == length
    assert sizes == [3, 2, 2, 2]

    # Summing every sample across all shards must match the original data.
    total = sum(sample for part in shards for sample in part)
    assert total == sum(a)

def test_dataset_take():
length = 100
a = mx.gluon.data.SimpleDataset([i for i in range(length)])
Expand Down

0 comments on commit f75d093

Please sign in to comment.