From 5af0bb7b17b6be3db68c60c252b8ff0cf98748ee Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Sun, 15 Sep 2019 17:40:02 +0000 Subject: [PATCH 1/4] add shard API --- python/mxnet/gluon/data/dataset.py | 31 ++++++++++++++++++++++++ python/mxnet/gluon/data/sampler.py | 10 +++++--- tests/python/unittest/test_gluon_data.py | 13 ++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index a5399f259435..6111d6179294 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -64,6 +64,37 @@ def filter(self, fn): from . import FilterSampler return _SampledDataset(self, FilterSampler(fn, self)) + def shard(self, num_shards, index): + """Returns a new dataset that includes only 1/num_shards of this dataset. + + For distributed training, be sure to shard before you randomize the dataset + (such as shuffle), if you want each worker to reach a unique subset. + + Parameters + ---------- + num_shards : int + An integer representing the number of data shards. + index : int + An integer representing the index of the current shard. + + Returns + ------- + Dataset + The result dataset. + """ + assert index < num_shards, 'Shard index of out bound: %d out of %d'%(index, num_shards) + assert num_shards > 0, 'Number of shards must be greater than 0' + length = len(self) + shard_len = length // num_shards + # Compute the start index for this partition + start = shard_len * index + # Compute the end index for this partition + end = start + shard_len + if index == num_shards - 1: + end = length + from . import SequentialSampler + return _SampledDataset(self, SequentialSampler(end - start, start)) + def take(self, count): """Returns a new dataset with at most `count` number of samples in it. 
diff --git a/python/mxnet/gluon/data/sampler.py b/python/mxnet/gluon/data/sampler.py index 980c63134fac..f53e636317df 100644 --- a/python/mxnet/gluon/data/sampler.py +++ b/python/mxnet/gluon/data/sampler.py @@ -36,23 +36,25 @@ def __len__(self): class SequentialSampler(Sampler): - """Samples elements from [0, length) sequentially. + """Samples elements from [start, start+length) sequentially. Parameters ---------- length : int Length of the sequence. + start : int, default is 0 + The start of the sequence index. """ - def __init__(self, length): + def __init__(self, length, start=0): self._length = length + self._start = start def __iter__(self): - return iter(range(self._length)) + return iter(range(self._start, self._start + self._length)) def __len__(self): return self._length - class RandomSampler(Sampler): """Samples elements from [0, length) randomly without replacement. diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 62bf7c3db869..f850dc3970fb 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -296,6 +296,19 @@ def test_dataset_filter(): for idx, sample in enumerate(a_xform_filtered): assert sample % 10 == 0 +def test_dataset_shard(): + length = 11 + a = mx.gluon.data.SimpleDataset([i for i in range(length)]) + a_shard_0 = a.shard(2, 0) + a_shard_1 = a.shard(2, 1) + assert len(a_shard_0) + len(a_shard_1) == length + total = 0 + for idx, sample in enumerate(a_shard_0): + total += sample + for idx, sample in enumerate(a_shard_1): + total += sample + assert total == sum(a) + def test_dataset_take(): length = 100 a = mx.gluon.data.SimpleDataset([i for i in range(length)]) From 3514ffcd7d39e32f74988f35808ca19438a8e2d5 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 16 Sep 2019 03:13:53 +0000 Subject: [PATCH 2/4] more balanced --- python/mxnet/gluon/data/dataset.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git 
a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index 6111d6179294..f8d54681270a 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -84,14 +84,13 @@ def shard(self, num_shards, index): """ assert index < num_shards, 'Shard index of out bound: %d out of %d'%(index, num_shards) assert num_shards > 0, 'Number of shards must be greater than 0' + assert index >= 0, 'Index must be non-negative' length = len(self) - shard_len = length // num_shards + shard_len = (length + num_shards - 1) // num_shards # Compute the start index for this partition start = shard_len * index # Compute the end index for this partition - end = start + shard_len - if index == num_shards - 1: - end = length + end = start + shard_len if index < num_shards - 1 else length from . import SequentialSampler return _SampledDataset(self, SequentialSampler(end - start, start)) From abd363ccf8412f6bb4619a0f7e66e4c18061241d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 16 Sep 2019 04:57:59 +0000 Subject: [PATCH 3/4] fix a bug --- python/mxnet/gluon/data/dataset.py | 7 ++++--- tests/python/unittest/test_gluon_data.py | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/python/mxnet/gluon/data/dataset.py b/python/mxnet/gluon/data/dataset.py index f8d54681270a..7f2c6342d595 100644 --- a/python/mxnet/gluon/data/dataset.py +++ b/python/mxnet/gluon/data/dataset.py @@ -86,11 +86,12 @@ def shard(self, num_shards, index): assert num_shards > 0, 'Number of shards must be greater than 0' assert index >= 0, 'Index must be non-negative' length = len(self) - shard_len = (length + num_shards - 1) // num_shards + shard_len = length // num_shards + rest = length % num_shards # Compute the start index for this partition - start = shard_len * index + start = shard_len * index + min(index, rest) # Compute the end index for this partition - end = start + shard_len if index < num_shards - 1 else length + end = start + 
shard_len + (index < rest) from . import SequentialSampler return _SampledDataset(self, SequentialSampler(end - start, start)) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index f850dc3970fb..9bbe6d0383b9 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -297,16 +297,21 @@ def test_dataset_filter(): assert sample % 10 == 0 def test_dataset_shard(): - length = 11 + length = 9 a = mx.gluon.data.SimpleDataset([i for i in range(length)]) - a_shard_0 = a.shard(2, 0) - a_shard_1 = a.shard(2, 1) - assert len(a_shard_0) + len(a_shard_1) == length + shard_0 = a.shard(4, 0) + shard_1 = a.shard(4, 1) + shard_2 = a.shard(4, 2) + shard_3 = a.shard(4, 3) + assert len(shard_0) + len(shard_1) + len(shard_2) + len(shard_3) == length + assert len(shard_0) > 0 + assert len(shard_1) > 0 + assert len(shard_2) > 0 + assert len(shard_3) > 0 total = 0 - for idx, sample in enumerate(a_shard_0): - total += sample - for idx, sample in enumerate(a_shard_1): - total += sample + for shard in [shard_0, shard_1, shard_2, shard_3]: + for idx, sample in enumerate(shard): + total += sample assert total == sum(a) def test_dataset_take(): From 796d693972bafa5b5770fe9ef9a19a1c2c447640 Mon Sep 17 00:00:00 2001 From: Haibin Lin Date: Tue, 17 Sep 2019 15:50:15 -0700 Subject: [PATCH 4/4] Update test_gluon_data.py --- tests/python/unittest/test_gluon_data.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_gluon_data.py b/tests/python/unittest/test_gluon_data.py index 9bbe6d0383b9..80bea2a511f2 100644 --- a/tests/python/unittest/test_gluon_data.py +++ b/tests/python/unittest/test_gluon_data.py @@ -304,10 +304,10 @@ def test_dataset_shard(): shard_2 = a.shard(4, 2) shard_3 = a.shard(4, 3) assert len(shard_0) + len(shard_1) + len(shard_2) + len(shard_3) == length - assert len(shard_0) > 0 - assert len(shard_1) > 0 - assert len(shard_2) > 0 - assert 
len(shard_3) > 0 + assert len(shard_0) == 3 + assert len(shard_1) == 2 + assert len(shard_2) == 2 + assert len(shard_3) == 2 total = 0 for shard in [shard_0, shard_1, shard_2, shard_3]: for idx, sample in enumerate(shard):