From ce4ae1ce22ec58657bb78dbe28482a547082a7ae Mon Sep 17 00:00:00 2001
From: stu1130
Date: Mon, 20 Aug 2018 15:26:50 -0700
Subject: [PATCH 01/18] 1. move the shuffle to the reset 2. modify the
 roll_over behavior accordingly

---
 python/mxnet/io.py | 155 ++++++++++++++++++++++++++++++---------------
 1 file changed, 105 insertions(+), 50 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 884e9294741a..cf739692f50b 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -38,7 +38,7 @@
 from .ndarray.sparse import array as sparse_array
 from .ndarray import _ndarray_cls
 from .ndarray import array
-from .ndarray import concatenate
+from .ndarray import concat
 from .ndarray import arange
 from .ndarray.random import shuffle as random_shuffle
 
@@ -635,6 +635,9 @@ class NDArrayIter(DataIter):
         How to handle the last batch. This parameter can be 'pad', 'discard' or
         'roll_over'. 'roll_over' is intended for training and can cause problems
         if used for prediction.
+        If 'pad', the last batch will be padded with data starting from the beginning
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next iteration
     data_name : str, optional
         The data name.
     label_name : str, optional
@@ -653,28 +656,23 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
             raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                       " with `last_batch_handle` set to `discard`.")
 
-        # shuffle data
-        if shuffle:
-            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
-            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
-            self.data = _shuffle(self.data, self.idx)
-            self.label = _shuffle(self.label, self.idx)
-        else:
-            self.idx = np.arange(self.data[0][1].shape[0])
-
-        # batching
+        self.idx = np.arange(self.data[0][1].shape[0])
+        self.shuffle = shuffle
+        self.last_batch_handle = last_batch_handle
+        self.batch_size = batch_size
+        self.cursor = -self.batch_size
+        # shuffle
+        self.reset()
+        # discard data with option discard
         if last_batch_handle == 'discard':
-            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
-            self.idx = self.idx[:new_n]
+            self._discard_data()
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
         self.num_data = self.idx.shape[0]
-        assert self.num_data >= batch_size, \
-            "batch_size needs to be smaller than data size."
-        self.cursor = -batch_size
-        self.batch_size = batch_size
-        self.last_batch_handle = last_batch_handle
+
+        self._cache_data = None
+        self._cache_label = None
 
     @property
     def provide_data(self):
@@ -697,8 +695,16 @@ def hard_reset(self):
         """Ignore roll over data and set to start."""
         self.cursor = -self.batch_size
 
     def reset(self):
-        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
-            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+        if self.shuffle:
+            self._shuffle()
+        if self.last_batch_handle == 'discard':
+            self._discard_data()
+        # last_batch_cursor = self.data[0][1].shape[0] - self.data[0][1].shape[0] % self.batch_size
+        if self.last_batch_handle == 'roll_over' and \
+            hasattr(self, 'num_data') and \
+            self.cursor > self.num_data - self.batch_size and \
+            self.cursor < self.num_data:
+            self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
 
@@ -708,52 +714,92 @@ def iter_next(self):
 
     def next(self):
         if self.iter_next():
-            return DataBatch(data=self.getdata(), label=self.getlabel(), \
+            data = self.getdata()
+            label = self.getlabel()
+            if data[0].shape[0] != self.batch_size:
+                self._cache_data = data
+                self._cache_label = label
+                raise StopIteration
+            return DataBatch(data=data, label=label, \
                     pad=self.getpad(), index=None)
         else:
             raise StopIteration
-
-    def _getdata(self, data_source):
-        """Load data from underlying arrays, internal use only."""
-        assert(self.cursor < self.num_data), "DataIter needs reset."
-        if self.cursor + self.batch_size <= self.num_data:
+
+    def _getdata(self, data_source, need_concat=False):
+        if not need_concat:
+            if self.cursor + self.batch_size < self.num_data:
+                end_idx = self.cursor + self.batch_size
+            else:
+                end_idx = self.num_data
             return [
                 # np.ndarray or NDArray case
-                x[1][self.cursor:self.cursor + self.batch_size]
+                x[1][self.cursor:end_idx]
                 if isinstance(x[1], (np.ndarray, NDArray)) else
                 # h5py (only supports indices in increasing order)
                 array(x[1][sorted(self.idx[
-                    self.cursor:self.cursor + self.batch_size])][[
+                    self.cursor:end_idx])][[
                         list(self.idx[self.cursor:
-                                      self.cursor + self.batch_size]).index(i)
+                                      end_idx]).index(i)
                         for i in sorted(self.idx[
-                            self.cursor:self.cursor + self.batch_size])
+                            self.cursor:end_idx])
                     ]]) for x in data_source
             ]
         else:
-            pad = self.batch_size - self.num_data + self.cursor
-            return [
-                # np.ndarray or NDArray case
-                concatenate([x[1][self.cursor:], x[1][:pad]])
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                concatenate([
-                    array(x[1][sorted(self.idx[self.cursor:])][[
-                        list(self.idx[self.cursor:]).index(i)
-                        for i in sorted(self.idx[self.cursor:])
-                    ]]),
-                    array(x[1][sorted(self.idx[:pad])][[
-                        list(self.idx[:pad]).index(i)
-                        for i in sorted(self.idx[:pad])
-                    ]])
-                ]) for x in data_source
-            ]
-
+            if self.last_batch_handle == 'roll_over':
+                assert self._cache_data is not None or self._cache_label is not None, \
+                    'next epoch should have cached data'
+                cache_data = self._cache_data if self._cache_data is not None else self._cache_label
+                if isinstance(data_source[0][1], (np.ndarray, NDArray)):
+                    data = [x[1][:self.cursor + self.batch_size] for x in data_source]
+                else:
+                    data = [array(x[1][sorted(self.idx[:self.cursor + self.batch_size])][[
+                        list(self.idx[:self.cursor + self.batch_size]).index(i)
+                        for i in sorted(self.idx[:self.cursor + self.batch_size])
+                    ]]) for x in data_source]
+
+                data = concat(cache_data[0], data[0], dim=0)
+                if self._cache_data is not None:
+                    self._cache_data = None
+                else:
+                    self._cache_label = None
+                return [data]
+            else:
+                pad = self.batch_size - self.num_data + self.cursor
+                return [
+                    # np.ndarray or NDArray case
+                    concat(x[1][self.cursor:], x[1][:pad], dim=0)
+                    if isinstance(x[1], (np.ndarray, NDArray)) else
+                    # h5py (only supports indices in increasing order)
+                    concat(
+                        array(x[1][sorted(self.idx[self.cursor:])][[
+                            list(self.idx[self.cursor:]).index(i)
+                            for i in sorted(self.idx[self.cursor:])
+                        ]]),
+                        array(x[1][sorted(self.idx[:pad])][[
+                            list(self.idx[:pad]).index(i)
+                            for i in sorted(self.idx[:pad])
+                        ]]), dim=0
+                    ) for x in data_source
+                ]
+
+    def _batchify(self, data_source):
+        """Load data from underlying arrays, internal use only."""
+        assert self.cursor < self.num_data, "DataIter needs reset."
+        if self.last_batch_handle == 'roll_over' and \
+            self.cursor < 0 and \
+            self.cursor > -self.batch_size:
+            return self._getdata(data_source, True)
+        elif self.last_batch_handle == 'pad' and \
+            self.cursor + self.batch_size > self.num_data:
+            return self._getdata(data_source, True)
+        else:
+            return self._getdata(data_source)
+
     def getdata(self):
-        return self._getdata(self.data)
+        return self._batchify(self.data)
 
     def getlabel(self):
-        return self._getdata(self.label)
+        return self._batchify(self.label)
 
     def getpad(self):
         if self.last_batch_handle == 'pad' and \
@@ -762,6 +808,15 @@ def getpad(self):
         else:
             return 0
 
+    def _shuffle(self):
+        np.random.shuffle(self.idx)
+        self.data = _shuffle(self.data, self.idx)
+        self.label = _shuffle(self.label, self.idx)
+
+    def _discard_data(self):
+        new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % self.batch_size
+        self.idx = self.idx[:new_n]
+        self.num_data = self.idx.shape[0]
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.

From 26a4c58af42c326f4432845b033053b07c2b6086 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Mon, 20 Aug 2018 16:24:57 -0700
Subject: [PATCH 02/18] refactor the concat part

---
 python/mxnet/io.py | 95 ++++++++++++++++++++++++++++++--------------
 1 file changed, 67 insertions(+), 28 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index cf739692f50b..fe232746dd52 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -745,42 +745,81 @@ def _getdata(self, data_source, need_concat=False):
                     ]]) for x in data_source
             ]
         else:
+            # if self.last_batch_handle == 'roll_over':
+            #     assert self._cache_data is not None or self._cache_label is not None, \
+            #         'next epoch should have cached data'
+            #     cache_data = self._cache_data if self._cache_data is not None else self._cache_label
+            #     data = [
+            #         concat(cache_data[0], x[1][:self.cursor + self.batch_size], dim=0)
+            #         if isinstance(x[1], (np.ndarray, NDArray)) else
+            #         concat(
+            #             cache_data[0],
+            #             array(x[1][sorted(self.idx[:self.cursor + self.batch_size])][[
+            #                 list(self.idx[:self.cursor + self.batch_size]).index(i)
+            #                 for i in sorted(self.idx[:self.cursor + self.batch_size])
+            #             ]]), dim=0
+            #         ) for x in data_source
+            #     ]
+            #     if self._cache_data is not None:
+            #         self._cache_data = None
+            #     else:
+            #         self._cache_label = None
+            #     return data
+            # else:
+            #     pad = self.batch_size - self.num_data + self.cursor
+            #     data = [
+            #         # np.ndarray or NDArray case
+            #         concat(x[1][self.cursor:], x[1][:pad], dim=0)
+            #         if isinstance(x[1], (np.ndarray, NDArray)) else
+            #         # h5py (only supports indices in increasing order)
+            #         concat(
+            #             array(x[1][sorted(self.idx[self.cursor:])][[
+            #                 list(self.idx[self.cursor:]).index(i)
+            #                 for i in sorted(self.idx[self.cursor:])
+            #             ]]),
+            #             array(x[1][sorted(self.idx[:pad])][[
+            #                 list(self.idx[:pad]).index(i)
+            #                 for i in sorted(self.idx[:pad])
+            #             ]]), dim=0
+            #         ) for x in data_source
+            #     ]
+            #     return data
             if self.last_batch_handle == 'roll_over':
                 assert self._cache_data is not None or self._cache_label is not None, \
                     'next epoch should have cached data'
                 cache_data = self._cache_data if self._cache_data is not None else self._cache_label
-                if isinstance(data_source[0][1], (np.ndarray, NDArray)):
-                    data = [x[1][:self.cursor + self.batch_size] for x in data_source]
-                else:
-                    data = [array(x[1][sorted(self.idx[:self.cursor + self.batch_size])][[
-                        list(self.idx[:self.cursor + self.batch_size]).index(i)
-                        for i in sorted(self.idx[:self.cursor + self.batch_size])
-                    ]]) for x in data_source]
-
-                data = concat(cache_data[0], data[0], dim=0)
+            if self.last_batch_handle == 'roll_over':
+                first_data = cache_data
+            elif isinstance(data_source[0][1], (np.ndarray, NDArray)):
+                first_data = [x[1][self.cursor:] for x in data_source]
+            else:
+                first_data = [
+                    array(x[1][sorted(self.idx[self.cursor:])][[
+                        list(self.idx[self.cursor:]).index(i)
+                        for i in sorted(self.idx[self.cursor:])
+                    ]]) for x in data_source]
+            if self.last_batch_handle == 'roll_over':
+                second_idx = self.cursor + self.batch_size
+            else:
+                # pad
+                second_idx = self.batch_size - self.num_data + self.cursor
+            data = [
+                concat(first_data[0], x[1][:second_idx], dim=0)
+                if isinstance(x[1], (np.ndarray, NDArray)) else
+                concat(
+                    first_data[0],
+                    array(x[1][sorted(self.idx[:second_idx])][[
+                        list(self.idx[:second_idx]).index(i)
+                        for i in sorted(self.idx[:second_idx])
+                    ]]), dim=0
+                ) for x in data_source
+            ]
+            if self.last_batch_handle == 'roll_over':
                 if self._cache_data is not None:
                     self._cache_data = None
                 else:
                     self._cache_label = None
-                return [data]
-            else:
-                pad = self.batch_size - self.num_data + self.cursor
-                return [
-                    # np.ndarray or NDArray case
-                    concat(x[1][self.cursor:], x[1][:pad], dim=0)
-                    if isinstance(x[1], (np.ndarray, NDArray)) else
-                    # h5py (only supports indices in increasing order)
-                    concat(
-                        array(x[1][sorted(self.idx[self.cursor:])][[
-                            list(self.idx[self.cursor:]).index(i)
-                            for i in sorted(self.idx[self.cursor:])
-                        ]]),
-                        array(x[1][sorted(self.idx[:pad])][[
-                            list(self.idx[:pad]).index(i)
-                            for i in sorted(self.idx[:pad])
-                        ]]), dim=0
-                    ) for x in data_source
-                ]
+            return data
 
     def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""

From 21d399b8c2dbc0734fd87b4cbcf0179123460743 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 09:47:37 -0700
Subject: [PATCH 03/18] refactor the code

---
 python/mxnet/io.py | 57 ++++++++--------------------------------------
 1 file changed, 10 insertions(+), 47 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index fe232746dd52..1e33b08629b2 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -701,9 +701,8 @@ def reset(self):
             self._discard_data()
         # last_batch_cursor = self.data[0][1].shape[0] - self.data[0][1].shape[0] % self.batch_size
         if self.last_batch_handle == 'roll_over' and \
-            hasattr(self, 'num_data') and \
-            self.cursor > self.num_data - self.batch_size and \
-            self.cursor < self.num_data:
+                hasattr(self, 'num_data') and \
+                self.num_data - self.batch_size < self.cursor < self.num_data:
             self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
@@ -724,7 +723,7 @@ def next(self):
                     pad=self.getpad(), index=None)
         else:
             raise StopIteration
-
+
     def _getdata(self, data_source, need_concat=False):
         if not need_concat:
             if self.cursor + self.batch_size < self.num_data:
@@ -745,45 +744,6 @@ def _getdata(self, data_source, need_concat=False):
                     ]]) for x in data_source
             ]
         else:
-            # if self.last_batch_handle == 'roll_over':
-            #     assert self._cache_data is not None or self._cache_label is not None, \
-            #         'next epoch should have cached data'
-            #     cache_data = self._cache_data if self._cache_data is not None else self._cache_label
-            #     data = [
-            #         concat(cache_data[0], x[1][:self.cursor + self.batch_size], dim=0)
-            #         if isinstance(x[1], (np.ndarray, NDArray)) else
-            #         concat(
-            #             cache_data[0],
-            #             array(x[1][sorted(self.idx[:self.cursor + self.batch_size])][[
-            #                 list(self.idx[:self.cursor + self.batch_size]).index(i)
-            #                 for i in sorted(self.idx[:self.cursor + self.batch_size])
-            #             ]]), dim=0
-            #         ) for x in data_source
-            #     ]
-            #     if self._cache_data is not None:
-            #         self._cache_data = None
-            #     else:
-            #         self._cache_label = None
-            #     return data
-            # else:
-            #     pad = self.batch_size - self.num_data + self.cursor
-            #     data = [
-            #         # np.ndarray or NDArray case
-            #         concat(x[1][self.cursor:], x[1][:pad], dim=0)
-            #         if isinstance(x[1], (np.ndarray, NDArray)) else
-            #         # h5py (only supports indices in increasing order)
-            #         concat(
-            #             array(x[1][sorted(self.idx[self.cursor:])][[
-            #                 list(self.idx[self.cursor:]).index(i)
-            #                 for i in sorted(self.idx[self.cursor:])
-            #             ]]),
-            #             array(x[1][sorted(self.idx[:pad])][[
-            #                 list(self.idx[:pad]).index(i)
-            #                 for i in sorted(self.idx[:pad])
-            #             ]]), dim=0
-            #         ) for x in data_source
-            #     ]
-            #     return data
             if self.last_batch_handle == 'roll_over':
                 assert self._cache_data is not None or self._cache_label is not None, \
                     'next epoch should have cached data'
@@ -825,15 +785,14 @@ def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""
         assert self.cursor < self.num_data, "DataIter needs reset."
         if self.last_batch_handle == 'roll_over' and \
-            self.cursor < 0 and \
-            self.cursor > -self.batch_size:
+                -self.batch_size < self.cursor < 0:
             return self._getdata(data_source, True)
         elif self.last_batch_handle == 'pad' and \
             self.cursor + self.batch_size > self.num_data:
-             return self._getdata(data_source, True)
+            return self._getdata(data_source, True)
         else:
             return self._getdata(data_source)
-
+
     def getdata(self):
         return self._batchify(self.data)
 
@@ -844,6 +803,10 @@ def getpad(self):
         if self.last_batch_handle == 'pad' and \
             self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
+        # check the first batch
+        elif self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            return -self.cursor
         else:
             return 0

From c14b4ef28ceefc65e6c785bb59069442d88cb29c Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 09:48:02 -0700
Subject: [PATCH 04/18] implement unit test for last_batch_handle

---
 tests/python/unittest/test_io.py | 49 ++++++++++++++++----------------
 1 file changed, 25 insertions(+), 24 deletions(-)

diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 4dfa69cc1050..ba4a5b238fd2 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -90,32 +90,33 @@ def test_Cifar10Rec():
 
 def test_NDArrayIter():
     data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
+    labels = np.ones([1000, 1])
     for i in range(1000):
         data[i] = i / 100
-        label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, True, last_batch_handle='pad')
-    batchidx = 0
-    for batch in dataiter:
-        batchidx += 1
-    assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, False, last_batch_handle='pad')
-    batchidx = 0
-    labelcount = [0 for i in range(10)]
-    for batch in dataiter:
-        label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-        for i in range(label.shape[0]):
-            labelcount[int(label[i])] += 1
-
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
-
+        labels[i] = i / 100
+
+    idx = 0
+    last_batch_handle_list = ['pad', 'discard' , 'roll_over']
+    labelcount_list = [(124, 100), (100, 96), (100, 96)]
+    batch_count_list = [8, 7, 7]
+
+    for idx in range(len(last_batch_handle_list)):
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx])
+        batch_count = 0
+        labelcount = [0 for i in range(10)]
+        tmp = 0
+        for batch in dataiter:
+            label = batch.label[0].asnumpy().flatten()
+            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx]
+            for i in range(label.shape[0]):
+                labelcount[int(label[i])] += 1
+            batch_count += 1
+        # assert result
+        assert(labelcount[0] == labelcount_list[idx][0]), last_batch_handle_list[idx]
+        assert(labelcount[8] == labelcount_list[idx][1]), last_batch_handle_list[idx]
+
+        assert batch_count == batch_count_list[idx]
 
 def test_NDArrayIter_h5py():

From e8fb56819cbf807c79385c4777eb252669dad293 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 14:24:50 -0700
Subject: [PATCH 05/18] refactor the getdata part

---
 python/mxnet/io.py | 146 +++++++++++++++++++--------------------------
 1 file changed, 61 insertions(+), 85 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 1e33b08629b2..051703a0f2b0 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -661,16 +661,13 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
         self.last_batch_handle = last_batch_handle
         self.batch_size = batch_size
         self.cursor = -self.batch_size
+        self.num_data = self.idx.shape[0]
         # shuffle
         self.reset()
-        # discard data with option discard
-        if last_batch_handle == 'discard':
-            self._discard_data()
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
-        self.num_data = self.idx.shape[0]
-
+        # used for 'roll_over'
         self._cache_data = None
         self._cache_label = None
 
@@ -692,17 +689,21 @@ def provide_label(self):
 
     def hard_reset(self):
         """Ignore roll over data and set to start."""
+        if self.shuffle:
+            self._shuffle()
         self.cursor = -self.batch_size
+        self._cache_data = None
+        self._cache_label = None
 
     def reset(self):
         if self.shuffle:
            self._shuffle()
-        if self.last_batch_handle == 'discard':
-            self._discard_data()
-        # last_batch_cursor = self.data[0][1].shape[0] - self.data[0][1].shape[0] % self.batch_size
+        # need to check if self.num_data exists
+        # the range below indicates the last batch
         if self.last_batch_handle == 'roll_over' and \
                 hasattr(self, 'num_data') and \
                 self.num_data - self.batch_size < self.cursor < self.num_data:
+            # self.cursor - self.num_data represents the data we have for the last batch
            self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
@@ -712,86 +713,66 @@ def iter_next(self):
         return self.cursor < self.num_data
 
     def next(self):
-        if self.iter_next():
-            data = self.getdata()
-            label = self.getlabel()
-            if data[0].shape[0] != self.batch_size:
-                self._cache_data = data
-                self._cache_label = label
-                raise StopIteration
-            return DataBatch(data=data, label=label, \
-                    pad=self.getpad(), index=None)
-        else:
+        if not self.iter_next():
             raise StopIteration
+        data = self.getdata()
+        label = self.getlabel()
+        # iter should stop when last batch is not complete
+        if data[0].shape[0] != self.batch_size:
+            # in this case, cache it for next epoch
+            self._cache_data = data
+            self._cache_label = label
+            raise StopIteration
+        return DataBatch(data=data, label=label, \
+            pad=self.getpad(), index=None)
+
+    def _getdata(self, data_source, start=None, end=None):
+        assert start is not None or end is not None, 'should at least specify start or end'
+        start = start if start is not None else 0
+        end = end if end is not None else data_source[0][1].shape[0]
+        s = slice(start, end)
+        return [
+            x[1][s]
+            if isinstance(x[1], (np.ndarray, NDArray)) else
+            # h5py (only supports indices in increasing order)
+            array(x[1][sorted(self.idx[s])][[
+                list(self.idx[s]).index(i)
+                for i in sorted(self.idx[s])
+            ]]) for x in data_source
+        ]
 
-    def _getdata(self, data_source, need_concat=False):
-        if not need_concat:
-            if self.cursor + self.batch_size < self.num_data:
-                end_idx = self.cursor + self.batch_size
-            else:
-                end_idx = self.num_data
-            return [
-                # np.ndarray or NDArray case
-                x[1][self.cursor:end_idx]
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                array(x[1][sorted(self.idx[
-                    self.cursor:end_idx])][[
-                        list(self.idx[self.cursor:
-                                      end_idx]).index(i)
-                        for i in sorted(self.idx[
-                            self.cursor:end_idx])
-                    ]]) for x in data_source
-            ]
-        else:
-            if self.last_batch_handle == 'roll_over':
-                assert self._cache_data is not None or self._cache_label is not None, \
-                    'next epoch should have cached data'
-                cache_data = self._cache_data if self._cache_data is not None else self._cache_label
-            if self.last_batch_handle == 'roll_over':
-                first_data = cache_data
-            elif isinstance(data_source[0][1], (np.ndarray, NDArray)):
-                first_data = [x[1][self.cursor:] for x in data_source]
-            else:
-                first_data = [
-                    array(x[1][sorted(self.idx[self.cursor:])][[
-                        list(self.idx[self.cursor:]).index(i)
-                        for i in sorted(self.idx[self.cursor:])
-                    ]]) for x in data_source]
-            if self.last_batch_handle == 'roll_over':
-                second_idx = self.cursor + self.batch_size
-            else:
-                # pad
-                second_idx = self.batch_size - self.num_data + self.cursor
-            data = [
-                concat(first_data[0], x[1][:second_idx], dim=0)
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                concat(
-                    first_data[0],
-                    array(x[1][sorted(self.idx[:second_idx])][[
-                        list(self.idx[:second_idx]).index(i)
-                        for i in sorted(self.idx[:second_idx])
-                    ]]), dim=0
-                ) for x in data_source
-            ]
-            if self.last_batch_handle == 'roll_over':
-                if self._cache_data is not None:
-                    self._cache_data = None
-                else:
-                    self._cache_label = None
-            return data
+    def _concat(self, first_data, second_data):
+        return [
+            concat(first_data[0], second_data[0], dim=0)
+        ]
 
     def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""
-        assert self.cursor < self.num_data, "DataIter needs reset."
+        assert self.cursor < self.num_data, 'DataIter needs reset.'
         if self.last_batch_handle == 'roll_over' and \
-            -self.batch_size < self.cursor < 0:
-            return self._getdata(data_source, True)
+                -self.batch_size < self.cursor < 0:
+            assert self._cache_data is not None or self._cache_label is not None, \
+                'next epoch should have cached data'
+            cache_data = self._cache_data if self._cache_data is not None else self._cache_label
+            second_data = self._getdata(
+                data_source, end=self.cursor + self.batch_size)
+            if self._cache_data is not None:
+                self._cache_data = None
+            else:
+                self._cache_label = None
+            return self._concat(cache_data, second_data)
         elif self.last_batch_handle == 'pad' and \
-            self.cursor + self.batch_size > self.num_data:
-            return self._getdata(data_source, True)
+                self.cursor + self.batch_size > self.num_data:
+            pad = self.batch_size - self.num_data + self.cursor
+            first_data = self._getdata(data_source, start=self.cursor)
+            second_data = self._getdata(data_source, end=pad)
+            return self._concat(first_data, second_data)
         else:
-            return self._getdata(data_source)
+            if self.cursor + self.batch_size < self.num_data:
+                end_idx = self.cursor + self.batch_size
+            else:
+                end_idx = self.num_data
+            return self._getdata(data_source, self.cursor, end_idx)
 
     def getdata(self):
         return self._batchify(self.data)
@@ -815,11 +796,6 @@ def getdata(self):
     def _shuffle(self):
         np.random.shuffle(self.idx)
         self.data = _shuffle(self.data, self.idx)
         self.label = _shuffle(self.label, self.idx)
-
-    def _discard_data(self):
-        new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % self.batch_size
-        self.idx = self.idx[:new_n]
-        self.num_data = self.idx.shape[0]
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
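The first five patches all orbit one mechanism: next() caches an incomplete final
batch, and the following epoch's first batch is the cached remainder concatenated
with the start of the reshuffled data, signalled by reset() leaving a negative
cursor. As a reading aid between patches, here is a minimal, self-contained Python
sketch of that control flow. It is not part of the series or of mxnet;
RollOverIter, its plain-list data, and every name in it are illustrative
assumptions that only mirror the cursor/cache logic above:

class RollOverIter:
    """Toy iterator mimicking NDArrayIter's 'roll_over' control flow."""

    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.num_data = len(data)
        self.cursor = -batch_size
        self._cache = None  # stands in for _cache_data / _cache_label

    def reset(self):
        # as in the patches: a cursor stranded inside the last (incomplete)
        # batch becomes a negative cursor for the next epoch
        if self.num_data - self.batch_size < self.cursor < self.num_data:
            self.cursor = self.cursor - self.num_data - self.batch_size
        else:
            self.cursor = -self.batch_size

    def __iter__(self):
        return self

    def __next__(self):
        self.cursor += self.batch_size          # iter_next()
        if self.cursor >= self.num_data:
            raise StopIteration
        if self.cursor < 0:                     # first batch after a roll over
            batch = self._cache + self.data[:self.cursor + self.batch_size]
            self._cache = None
            return batch
        batch = self.data[self.cursor:self.cursor + self.batch_size]
        if len(batch) < self.batch_size:        # incomplete last batch:
            self._cache = batch                 # cache it and stop the epoch
            raise StopIteration
        return batch

it = RollOverIter(list(range(10)), batch_size=3)
print(list(it))    # [[0, 1, 2], [3, 4, 5], [6, 7, 8]] -- 9 is cached
it.reset()
print(next(it))    # [9, 0, 1] -- the cached remainder rolls over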
From 33290f60e42410b554daf3dc02adc25a19a20396 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 15:34:52 -0700
Subject: [PATCH 06/18] add docstring and refine the code according to linter

---
 python/mxnet/io.py | 19 ++++++++++++++-----
 1 file changed, 14 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 051703a0f2b0..8bf43eaa66ca 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -39,8 +39,6 @@
 from .ndarray import _ndarray_cls
 from .ndarray import array
 from .ndarray import concat
-from .ndarray import arange
-from .ndarray.random import shuffle as random_shuffle
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -696,23 +694,24 @@ def hard_reset(self):
         self._cache_label = None
 
     def reset(self):
+        """Resets the iterator to the beginning of the data."""
         if self.shuffle:
             self._shuffle()
-        # need to check if self.num_data exists
         # the range below indicates the last batch
         if self.last_batch_handle == 'roll_over' and \
-                hasattr(self, 'num_data') and \
                 self.num_data - self.batch_size < self.cursor < self.num_data:
-            # self.cursor - self.num_data represents the data we have for the last batch
+            # (self.cursor - self.num_data) represents the data we have for the last batch
             self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
 
     def iter_next(self):
+        """Increments the cursor and check current cursor if exceed num of data."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
     def next(self):
+        """Returns the next batch of data."""
         if not self.iter_next():
             raise StopIteration
         data = self.getdata()
@@ -727,6 +726,7 @@ def next(self):
             pad=self.getpad(), index=None)
 
     def _getdata(self, data_source, start=None, end=None):
+        """Load data from underlying arrays."""
         assert start is not None or end is not None, 'should at least specify start or end'
         start = start if start is not None else 0
         end = end if end is not None else data_source[0][1].shape[0]
@@ -742,6 +742,7 @@ def _getdata(self, data_source, start=None, end=None):
         ]
 
     def _concat(self, first_data, second_data):
+        """Helper function to concat two NDArrays."""
         return [
             concat(first_data[0], second_data[0], dim=0)
         ]
 
     def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""
         assert self.cursor < self.num_data, 'DataIter needs reset.'
+        # first batch of next epoch with 'roll_over'
         if self.last_batch_handle == 'roll_over' and \
                 -self.batch_size < self.cursor < 0:
             assert self._cache_data is not None or self._cache_label is not None, \
                 'next epoch should have cached data'
             cache_data = self._cache_data if self._cache_data is not None else self._cache_label
             second_data = self._getdata(
                 data_source, end=self.cursor + self.batch_size)
             if self._cache_data is not None:
                 self._cache_data = None
             else:
                 self._cache_label = None
             return self._concat(cache_data, second_data)
+        # last batch with 'pad'
         elif self.last_batch_handle == 'pad' and \
                 self.cursor + self.batch_size > self.num_data:
             pad = self.batch_size - self.num_data + self.cursor
             first_data = self._getdata(data_source, start=self.cursor)
             second_data = self._getdata(data_source, end=pad)
             return self._concat(first_data, second_data)
+        # normal case
         else:
             if self.cursor + self.batch_size < self.num_data:
                 end_idx = self.cursor + self.batch_size
+            # get incomplete last batch
             else:
                 end_idx = self.num_data
             return self._getdata(data_source, self.cursor, end_idx)
 
     def getdata(self):
+        """Get data."""
         return self._batchify(self.data)
 
     def getlabel(self):
+        """Get label."""
         return self._batchify(self.label)
 
     def getpad(self):
+        """Get pad value of DataBatch."""
         if self.last_batch_handle == 'pad' and \
             self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
         # check the first batch
         elif self.last_batch_handle == 'roll_over' and \
             -self.batch_size < self.cursor < 0:
             return -self.cursor
         else:
             return 0
 
     def _shuffle(self):
+        """Shuffle the data."""
         np.random.shuffle(self.idx)
         self.data = _shuffle(self.data, self.idx)
         self.label = _shuffle(self.label, self.idx)

From 3eb93c55cf59568bbe5626d1179deb4572e2aeed Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 16:12:12 -0700
Subject: [PATCH 07/18] 1. add test case for NDArrayIter_h5py 2. refactor the
 implementation

---
 tests/python/unittest/test_io.py | 57 ++++++++++++--------------------
 1 file changed, 21 insertions(+), 36 deletions(-)

diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index ba4a5b238fd2..3c14fa63bd16 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -87,14 +87,15 @@ def test_Cifar10Rec():
     for i in range(10):
         assert(labelcount[i] == 5000)
 
-
-def test_NDArrayIter():
+def _init_NDArrayIter_data():
     data = np.ones([1000, 2, 2])
     labels = np.ones([1000, 1])
     for i in range(1000):
         data[i] = i / 100
         labels[i] = i / 100
+    return data, labels
+
+def _test_last_batch_handle(data, labels):
     idx = 0
     last_batch_handle_list = ['pad', 'discard' , 'roll_over']
     labelcount_list = [(124, 100), (100, 96), (100, 96)]
     batch_count_list = [8, 7, 7]
@@ -105,7 +106,6 @@ def test_NDArrayIter():
             data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx])
         batch_count = 0
         labelcount = [0 for i in range(10)]
-        tmp = 0
         for batch in dataiter:
             label = batch.label[0].asnumpy().flatten()
             assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx]
@@ -117,53 +117,38 @@ def test_NDArrayIter():
         assert(labelcount[8] == labelcount_list[idx][1]), last_batch_handle_list[idx]
 
         assert batch_count == batch_count_list[idx]
+        # shuffle equals True for sanity test
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, True, last_batch_handle=last_batch_handle_list[idx])
+        batch_count = 0
+        for _ in dataiter:
+            batch_count += 1
+        assert batch_count == batch_count_list[idx]
+
+def test_NDArrayIter():
+    data, labels = _init_NDArrayIter_data()
+    _test_last_batch_handle(data, labels)
 
 def test_NDArrayIter_h5py():
     if not h5py:
         return
 
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
-    for i in range(1000):
-        data[i] = i / 100
-        label[i] = i / 100
+    data, labels = _init_NDArrayIter_data()
 
     try:
-        os.remove("ndarraytest.h5")
+        os.remove('ndarraytest.h5')
     except OSError:
         pass
-    with h5py.File("ndarraytest.h5") as f:
-        f.create_dataset("data", data=data)
-        f.create_dataset("label", data=label)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, True, last_batch_handle='pad')
-        batchidx = 0
-        for batch in dataiter:
-            batchidx += 1
-        assert(batchidx == 8)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, False, last_batch_handle='pad')
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
+    with h5py.File('ndarraytest.h5') as f:
+        f.create_dataset('data', data=data)
+        f.create_dataset('label', data=labels)
+        _test_last_batch_handle(f['data'], f['label'])
 
     try:
         os.remove("ndarraytest.h5")
     except OSError:
         pass
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
-
-
 def test_NDArrayIter_csr():

From 4ace292c2df89e3439fd1683c526de6055e71667 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 16:14:26 -0700
Subject: [PATCH 08/18] update contributions doc

---
 CONTRIBUTORS.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 8d8aeaca73e4..1c005d57c4a6 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,3 +178,4 @@ List of Contributors
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
 * [Per Goncalves da Silva](https://github.com/perdasilva)
+* [Cheng-Che Lee](https://github.com/stu1130)

From e939b069ed02eb7c64c8addceea9e4f54a928826 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 21 Aug 2018 16:53:43 -0700
Subject: [PATCH 09/18] fix wording

---
 python/mxnet/io.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 8bf43eaa66ca..c768fa817cd4 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -631,11 +631,11 @@ class NDArrayIter(DataIter):
         Only supported if no h5py.Dataset inputs are used.
     last_batch_handle : str, optional
         How to handle the last batch. This parameter can be 'pad', 'discard' or
-        'roll_over'. 'roll_over' is intended for training and can cause problems
-        if used for prediction.
+        'roll_over'.
         If 'pad', the last batch will be padded with data starting from the beginning
         If 'discard', the last batch will be discarded
-        If 'roll_over', the remaining elements will be rolled over to the next iteration
+        If 'roll_over', the remaining elements will be rolled over to the next iteration and
+        note that it is intended for training and can cause problems if used for prediction.
     data_name : str, optional
         The data name.
     label_name : str, optional

From 21ccec69c8ac41e7883d22383e7f2a1d9a5e99f1 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 09:50:21 -0700
Subject: [PATCH 10/18] update doc for roll_over

---
 python/mxnet/io.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index c768fa817cd4..ed4abda8c448 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -599,6 +599,22 @@ class NDArrayIter(DataIter):
     ...
     >>> batchidx  # Remaining examples are discarded. So, 10/3 batches are created.
     3
+    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
+    >>> batchidx = 0
+    >>> for batch in dataiter:
+    ...     batchidx += 1
+    ...
+    >>> batchidx  # Remaining examples are rolled over to the next iteration.
+    3
+    >>> dataiter.reset()
+    >>> dataiter.next().data[0].asnumpy()
+    [[[ 36.  37.]
+      [ 38.  39.]]
+     [[ 0.  1.]
+      [ 2.  3.]]
+     [[ 4.  5.]
+      [ 6.  7.]]]
+    (3L, 2L, 2L)
 
     `NDArrayIter` also supports multiple input and labels.

From 353afe7c964c324c87d10af47f21bd959cd65f29 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 15:42:08 -0700
Subject: [PATCH 11/18] 1. add test for second iteration of roll_over 2. add
 shuffle test case

---
 tests/python/unittest/test_io.py | 58 ++++++++++++++++++++++++--------
 1 file changed, 44 insertions(+), 14 deletions(-)

diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 3c14fa63bd16..ae686261b818 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -87,6 +87,7 @@ def test_Cifar10Rec():
     for i in range(10):
         assert(labelcount[i] == 5000)
 
+
 def _init_NDArrayIter_data():
     data = np.ones([1000, 2, 2])
     labels = np.ones([1000, 1])
@@ -95,12 +96,13 @@ def _init_NDArrayIter_data():
         labels[i] = i / 100
     return data, labels
 
+
 def _test_last_batch_handle(data, labels):
-    idx = 0
+    # Test the three parameters 'pad', 'discard', 'roll_over'
     last_batch_handle_list = ['pad', 'discard' , 'roll_over']
     labelcount_list = [(124, 100), (100, 96), (100, 96)]
     batch_count_list = [8, 7, 7]
-
+
     for idx in range(len(last_batch_handle_list)):
         dataiter = mx.io.NDArrayIter(
             data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx])
@@ -108,26 +110,45 @@ def _test_last_batch_handle(data, labels):
         labelcount = [0 for i in range(10)]
         for batch in dataiter:
             label = batch.label[0].asnumpy().flatten()
+            # check data if it matches corresponding labels
             assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), last_batch_handle_list[idx]
             for i in range(label.shape[0]):
                 labelcount[int(label[i])] += 1
+            # keep the last batch of 'pad' to be used later
+            # to test first batch of roll_over in second iteration
             batch_count += 1
-        # assert result
-        assert(labelcount[0] == labelcount_list[idx][0]), last_batch_handle_list[idx]
-        assert(labelcount[8] == labelcount_list[idx][1]), last_batch_handle_list[idx]
-
-        assert batch_count == batch_count_list[idx]
-        # shuffle equals True for sanity test
-        dataiter = mx.io.NDArrayIter(
-            data, labels, 128, True, last_batch_handle=last_batch_handle_list[idx])
-        batch_count = 0
-        for _ in dataiter:
-            batch_count += 1
+            if last_batch_handle_list[idx] == 'pad' and \
+                batch_count == 8:
+                cache = batch.data[0].asnumpy()
+        # check if batchifying functionality work properly
+        assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx]
+        assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx]
         assert batch_count == batch_count_list[idx]
+    # roll_over option
+    dataiter.reset()
+    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels):
+    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+    batch_list = []
+    for batch in dataiter:
+        # cache the original data
+        batch_list.append(batch.data[0].asnumpy())
+    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+    idx_list = dataiter.idx
+    i = 0
+    for batch in dataiter:
+        # check if each data point have been shuffled to corresponding positions
+        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
+        i += 1
+
 
 def test_NDArrayIter():
     data, labels = _init_NDArrayIter_data()
     _test_last_batch_handle(data, labels)
+    _test_shuffle(data, labels)
+
 
 def test_NDArrayIter_h5py():
     if not h5py:
@@ -149,6 +170,7 @@ def test_NDArrayIter_h5py():
     except OSError:
         pass
 
+
 def test_NDArrayIter_csr():
     # creating toy data
     num_rows = rnd.randint(5, 15)
@@ -168,12 +190,20 @@ def test_NDArrayIter_csr():
             {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
+    # scipy.sparse.csr_matrix with shuffle
+    num_batch = 0
+    csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
+                                      shuffle=True, last_batch_handle='discard'))
+    for _ in csr_iter:
+        num_batch += 1
+
+    assert(num_batch == num_rows // batch_size)
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
                                       shuffle=True, last_batch_handle='discard'))
     num_batch = 0
-    for batch in csr_iter:
+    for _ in csr_iter:
         num_batch += 1
 
     assert(num_batch == num_rows // batch_size)

From 012a4196d4158d9c5c14016a2fbd59d69a3d9a66 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 15:42:51 -0700
Subject: [PATCH 12/18] fix some wording and refine the variables naming

---
 python/mxnet/io.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index ed4abda8c448..be201411aac6 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -704,7 +704,7 @@ def provide_label(self):
     def hard_reset(self):
         """Ignore roll over data and set to start."""
         if self.shuffle:
-            self._shuffle()
+            self._shuffle_data()
         self.cursor = -self.batch_size
         self._cache_data = None
         self._cache_label = None
@@ -712,7 +712,7 @@ def hard_reset(self):
     def reset(self):
         """Resets the iterator to the beginning of the data."""
         if self.shuffle:
-            self._shuffle()
+            self._shuffle_data()
         # the range below indicates the last batch
         if self.last_batch_handle == 'roll_over' and \
                 self.num_data - self.batch_size < self.cursor < self.num_data:
@@ -722,7 +722,8 @@ def reset(self):
             self.cursor = -self.batch_size
 
     def iter_next(self):
-        """Increments the cursor and check current cursor if exceed num of data."""
+        """Increments the cursor by batch_size for next batch
+        and check current cursor if it exceeds the number of data points."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
@@ -815,9 +816,11 @@ def getpad(self):
         else:
             return 0
 
-    def _shuffle(self):
+    def _shuffle_data(self):
         """Shuffle the data."""
+        # shuffle index
         np.random.shuffle(self.idx)
+        # get the data with corresponding index
         self.data = _shuffle(self.data, self.idx)
         self.label = _shuffle(self.label, self.idx)
 
@@ -831,7 +834,7 @@ class MXDataIter(DataIter):
     underlying C++ data iterators.
 
     Usually you don't need to interact with `MXDataIter` directly unless you are
-    implementing your own data iterators in C++. To do that, please refer to
+    implementing your own data iterators in C+ +. To do that, please refer to
     examples under the `src/io` folder.

From 68c10f24f9b926c1e9d59642b89e9b49e03f75be Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 16:14:49 -0700
Subject: [PATCH 13/18] move utility function to new file

---
 python/mxnet/io_utils.py | 86 ++++++++++++++++++++++++++++++
 1 file changed, 86 insertions(+)
 create mode 100644 python/mxnet/io_utils.py

diff --git a/python/mxnet/io_utils.py b/python/mxnet/io_utils.py
new file mode 100644
index 000000000000..dca4d743daa0
--- /dev/null
+++ b/python/mxnet/io_utils.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""utility functions for io.py"""
+from collections import OrderedDict
+
+import numpy as np
+try:
+    import h5py
+except ImportError:
+    h5py = None
+
+from .ndarray.sparse import CSRNDArray
+from .ndarray.sparse import array as sparse_array
+from .ndarray import NDArray
+from .ndarray import array
+
+def init_data(data, allow_empty, default_name):
+    """Convert data into canonical form."""
+    assert (data is not None) or allow_empty
+    if data is None:
+        data = []
+
+    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+                  if h5py else (np.ndarray, NDArray)):
+        data = [data]
+    if isinstance(data, list):
+        if not allow_empty:
+            assert(len(data) > 0)
+        if len(data) == 1:
+            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
+        else:
+            data = OrderedDict(  # pylint: disable=redefined-variable-type
+                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
+    if not isinstance(data, dict):
+        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
+                        "a list of them or dict with them as values")
+    for k, v in data.items():
+        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+            try:
+                data[k] = array(v)
+            except:
+                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
+                                "should be NDArray, numpy.ndarray or h5py.Dataset")
+
+    return list(sorted(data.items()))
+
+
+def has_instance(data, dtype):
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
+    return False
+
+
+def shuffle(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        if (isinstance(v, h5py.Dataset) if h5py else False):
+            shuffle_data.append((k, v))
+        elif isinstance(v, CSRNDArray):
+            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+        else:
+            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data

From 12059d4b6ad3006130e9c576132b63f7a9e0c5a5 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 16:15:54 -0700
Subject: [PATCH 14/18] move utility function to io_utils.py

---
 python/mxnet/io.py | 75 +++++++---------------------------------------
 1 file changed, 10 insertions(+), 65 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index be201411aac6..39245702b877 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -17,17 +17,14 @@
 """Data iterators for common data formats."""
 from __future__ import absolute_import
 
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 import sys
 import ctypes
 import logging
 import threading
-try:
-    import h5py
-except ImportError:
-    h5py = None
 import numpy as np
+
 from .base import _LIB
 from .base import c_str_array, mx_uint, py_str
 from .base import DataIterHandle, NDArrayHandle
@@ -35,11 +32,12 @@
 from .base import check_call, build_param_doc as _build_param_doc
 from .ndarray import NDArray
 from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
 from .ndarray import _ndarray_cls
 from .ndarray import array
 from .ndarray import concat
 
+from .io_utils import init_data, has_instance, shuffle
+
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
     information of the data or the label.
@@ -487,59 +485,6 @@ def getindex(self):
     def getpad(self):
         return self.current_batch.pad
 
-def _init_data(data, allow_empty, default_name):
-    """Convert data into canonical form."""
-    assert (data is not None) or allow_empty
-    if data is None:
-        data = []
-
-    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
-                  if h5py else (np.ndarray, NDArray)):
-        data = [data]
-    if isinstance(data, list):
-        if not allow_empty:
-            assert(len(data) > 0)
-        if len(data) == 1:
-            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
-        else:
-            data = OrderedDict(  # pylint: disable=redefined-variable-type
-                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
-    if not isinstance(data, dict):
-        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
-                        "a list of them or dict with them as values")
-    for k, v in data.items():
-        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
-            try:
-                data[k] = array(v)
-            except:
-                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
-                                "should be NDArray, numpy.ndarray or h5py.Dataset")
-
-    return list(sorted(data.items()))
-
-def _has_instance(data, dtype):
-    """Return True if ``data`` has instance of ``dtype``.
-    This function is called after _init_data.
-    ``data`` is a list of (str, NDArray)"""
-    for item in data:
-        _, arr = item
-        if isinstance(arr, dtype):
-            return True
-    return False
-
-def _shuffle(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        if (isinstance(v, h5py.Dataset) if h5py else False):
-            shuffle_data.append((k, v))
-        elif isinstance(v, CSRNDArray):
-            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
-        else:
-            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
@@ -662,10 +607,10 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
                  label_name='softmax_label'):
         super(NDArrayIter, self).__init__(batch_size)
 
-        self.data = _init_data(data, allow_empty=False, default_name=data_name)
-        self.label = _init_data(label, allow_empty=True, default_name=label_name)
+        self.data = init_data(data, allow_empty=False, default_name=data_name)
+        self.label = init_data(label, allow_empty=True, default_name=label_name)
 
-        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
+        if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                       " with `last_batch_handle` set to `discard`.")
@@ -820,9 +765,9 @@ def _shuffle_data(self):
         """Shuffle the data."""
         # shuffle index
         np.random.shuffle(self.idx)
-        # get the data with corresponding index
-        self.data = _shuffle(self.data, self.idx)
-        self.label = _shuffle(self.label, self.idx)
+        # get the data by corresponding index
+        self.data = shuffle(self.data, self.idx)
+        self.label = shuffle(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.

From c3f5b73d1f43ebd31b87d753cccded61c8122c50 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Wed, 22 Aug 2018 16:30:08 -0700
Subject: [PATCH 15/18] change shuffle function name to avoid redefining name

---
 python/mxnet/io.py       | 6 +++---
 python/mxnet/io_utils.py | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index 39245702b877..38d0a39346d4 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -36,7 +36,7 @@
 from .ndarray import array
 from .ndarray import concat
 
-from .io_utils import init_data, has_instance, shuffle
+from .io_utils import init_data, has_instance, getdata_by_idx
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -766,8 +766,8 @@ def _shuffle_data(self):
         # shuffle index
         np.random.shuffle(self.idx)
         # get the data by corresponding index
-        self.data = shuffle(self.data, self.idx)
-        self.label = shuffle(self.label, self.idx)
+        self.data = getdata_by_idx(self.data, self.idx)
+        self.label = getdata_by_idx(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
diff --git a/python/mxnet/io_utils.py b/python/mxnet/io_utils.py
index dca4d743daa0..12251d4c19c0 100644
--- a/python/mxnet/io_utils.py
+++ b/python/mxnet/io_utils.py
@@ -71,7 +71,7 @@ def has_instance(data, dtype):
     return False
 
 
-def shuffle(data, idx):
+def getdata_by_idx(data, idx):
     """Shuffle the data."""
     shuffle_data = []
 

From fa1b23dfb726b037a1296f2891f8001153782fe6 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 4 Sep 2018 14:49:13 -0700
Subject: [PATCH 16/18] make io as a module

---
 python/mxnet/io/__init__.py       | 28 ++++++++++++++++++++++++++++
 python/mxnet/{ => io}/io.py       | 20 ++++++++++----------
 python/mxnet/{ => io}/io_utils.py |  8 ++++----
 3 files changed, 42 insertions(+), 14 deletions(-)
 create mode 100644 python/mxnet/io/__init__.py
 rename python/mxnet/{ => io}/io.py (98%)
 rename python/mxnet/{ => io}/io_utils.py (95%)

diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
new file mode 100644
index 000000000000..71fa95f11529
--- /dev/null
+++ b/python/mxnet/io/__init__.py
@@ -0,0 +1,28 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+""" Data iterators for common data formats and utility functions."""
+from __future__ import absolute_import
+
+from . import io
+from .io import *
+
+from . import io_utils
+from .io_utils import *

diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py
similarity index 98%
rename from python/mxnet/io.py
rename to python/mxnet/io/io.py
index 38d0a39346d4..87eae7cb4bcf 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io/io.py
@@ -25,16 +25,16 @@
 import threading
 import numpy as np
 
-from .base import _LIB
-from .base import c_str_array, mx_uint, py_str
-from .base import DataIterHandle, NDArrayHandle
-from .base import mx_real_t
-from .base import check_call, build_param_doc as _build_param_doc
-from .ndarray import NDArray
-from .ndarray.sparse import CSRNDArray
-from .ndarray import _ndarray_cls
-from .ndarray import array
-from .ndarray import concat
+from ..base import _LIB
+from ..base import c_str_array, mx_uint, py_str
+from ..base import DataIterHandle, NDArrayHandle
+from ..base import mx_real_t
+from ..base import check_call, build_param_doc as _build_param_doc
+from ..ndarray import NDArray
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray import _ndarray_cls
+from ..ndarray import array
+from ..ndarray import concat
 
 from .io_utils import init_data, has_instance, getdata_by_idx
 
diff --git a/python/mxnet/io_utils.py b/python/mxnet/io/io_utils.py
similarity index 95%
rename from python/mxnet/io_utils.py
rename to python/mxnet/io/io_utils.py
index 12251d4c19c0..872e6410d7de 100644
--- a/python/mxnet/io_utils.py
+++ b/python/mxnet/io/io_utils.py
@@ -24,10 +24,10 @@
 except ImportError:
     h5py = None
 
-from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
-from .ndarray import NDArray
-from .ndarray import array
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray.sparse import array as sparse_array
+from ..ndarray import NDArray
+from ..ndarray import array
 
 def init_data(data, allow_empty, default_name):

From 4966e73ed8c5e629d49048a94468f593cee279b3 Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 4 Sep 2018 14:57:37 -0700
Subject: [PATCH 17/18] rename the utility functions

---
 python/mxnet/io/__init__.py               | 4 ++--
 python/mxnet/io/io.py                     | 2 +-
 python/mxnet/io/{io_utils.py => utils.py} | 0
 3 files changed, 3 insertions(+), 3 deletions(-)
 rename python/mxnet/io/{io_utils.py => utils.py} (100%)

diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
index 71fa95f11529..d1542ffdada9 100644
--- a/python/mxnet/io/__init__.py
+++ b/python/mxnet/io/__init__.py
@@ -24,5 +24,5 @@
 from . import io
 from .io import *
 
-from . import io_utils
-from .io_utils import *
+from . import utils
+from .utils import *

diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py
index 87eae7cb4bcf..2ae3e70045fb 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io/io.py
@@ -36,7 +36,7 @@
 from ..ndarray import array
 from ..ndarray import concat
 
-from .io_utils import init_data, has_instance, getdata_by_idx
+from .utils import init_data, has_instance, getdata_by_idx
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout

diff --git a/python/mxnet/io/io_utils.py b/python/mxnet/io/utils.py
similarity index 100%
rename from python/mxnet/io/io_utils.py
rename to python/mxnet/io/utils.py

From 2fb14f3515030e4eb8a6a7b90f1baaed38416abb Mon Sep 17 00:00:00 2001
From: stu1130
Date: Tue, 4 Sep 2018 15:04:56 -0700
Subject: [PATCH 18/18] disable wildcard-import

---
 python/mxnet/io/__init__.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
index d1542ffdada9..5c5e2e68d84a 100644
--- a/python/mxnet/io/__init__.py
+++ b/python/mxnet/io/__init__.py
@@ -18,6 +18,7 @@
 # under the License.
 
 # coding: utf-8
+# pylint: disable=wildcard-import
""" Data iterators for common data formats and utility functions."""
 from __future__ import absolute_import
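
Taken together, the series leaves mx.io.NDArrayIter with the semantics exercised
by the doctest in patch 10 and the tests in patches 04, 07, and 11. The short
script below is a quick end-to-end check of those semantics; it assumes an mxnet
build with this series applied, and the expected pads and batch contents noted in
the comments additionally assume shuffle=False:

import numpy as np
import mxnet as mx

data = np.arange(40).reshape((10, 2, 2))
labels = np.arange(10)

for mode in ('pad', 'discard', 'roll_over'):
    it = mx.io.NDArrayIter(data, labels, batch_size=3, shuffle=False,
                           last_batch_handle=mode)
    # pad       -> pads of [0, 0, 0, 2]: the last batch reuses data from the start
    # discard   -> pads of [0, 0, 0]:    the incomplete batch is dropped
    # roll_over -> pads of [0, 0, 0]:    the remainder is cached for the next epoch
    print(mode, [batch.pad for batch in it])
    it.reset()
    # in every mode the first batch of the second epoch is complete
    print(mode, next(it).data[0].shape)  # (3, 2, 2)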