Change the way NDArrayIter handle the last batch #12285
@@ -17,30 +17,26 @@

"""Data iterators for common data formats."""
from __future__ import absolute_import
from collections import OrderedDict, namedtuple
from collections import namedtuple

import sys
import ctypes
import logging
import threading
try:
    import h5py
except ImportError:
    h5py = None
import numpy as np

from .base import _LIB
from .base import c_str_array, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .ndarray.sparse import CSRNDArray
from .ndarray.sparse import array as sparse_array
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate
from .ndarray import arange
from .ndarray.random import shuffle as random_shuffle
from .ndarray import concat

from .io_utils import init_data, has_instance, getdata_by_idx

class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
    """DataDesc is used to store name, shape, type and layout
@@ -489,59 +485,6 @@ def getindex(self):

    def getpad(self):
        return self.current_batch.pad

def _init_data(data, allow_empty, default_name):
    """Convert data into canonical form."""
    assert (data is not None) or allow_empty
    if data is None:
        data = []

    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    if isinstance(data, list):
        if not allow_empty:
            assert(len(data) > 0)
        if len(data) == 1:
            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
        else:
            data = OrderedDict(  # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                        "a list of them or dict with them as values")
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            except:
                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")

    return list(sorted(data.items()))

def _has_instance(data, dtype):
    """Return True if ``data`` has instance of ``dtype``.
    This function is called after _init_data.
    ``data`` is a list of (str, NDArray)"""
    for item in data:
        _, arr = item
        if isinstance(arr, dtype):
            return True
    return False

def _shuffle(data, idx):
    """Shuffle the data."""
    shuffle_data = []

    for k, v in data:
        if (isinstance(v, h5py.Dataset) if h5py else False):
            shuffle_data.append((k, v))
        elif isinstance(v, CSRNDArray):
            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
        else:
            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))

    return shuffle_data

class NDArrayIter(DataIter):
    """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
@@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
    ...
    >>> batchidx  # Remaining examples are discarded. So, 10/3 batches are created.
    3
    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
    >>> batchidx = 0
    >>> for batch in dataiter:
    ...     batchidx += 1
    ...
    >>> batchidx  # Remaining examples are rolled over to the next iteration.
    3
    >>> dataiter.reset()
    >>> dataiter.next().data[0].asnumpy()
    [[[ 36.  37.]
      [ 38.  39.]]
     [[  0.   1.]
      [  2.   3.]]
     [[  4.   5.]
      [  6.   7.]]]
    (3L, 2L, 2L)

    `NDArrayIter` also supports multiple input and labels.

@@ -633,8 +592,11 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
        Only supported if no h5py.Dataset inputs are used.
    last_batch_handle : str, optional
        How to handle the last batch. This parameter can be 'pad', 'discard' or
        'roll_over'. 'roll_over' is intended for training and can cause problems
        if used for prediction.
        'roll_over'.
        If 'pad', the last batch will be padded with data starting from the beginning.
        If 'discard', the last batch will be discarded.
        If 'roll_over', the remaining elements will be rolled over to the next
        iteration; note that this mode is intended for training and can cause
        problems if used for prediction.
    data_name : str, optional
        The data name.
    label_name : str, optional

[Review thread on the 'pad' line]
Reviewer: How are 'pad' and 'roll_over' different? It is not clear in the
documentation. In both cases it would seem you are taking data from the first
batch of the next epoch and adding it to the current last batch.
Author: Say the data looks like this: [1,2,3,4,5,6,7,8,9,10] with batch_size 3.
Reviewer: Yeah, it's very clear with an example.
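Following up on that thread, here is a hedged worked sketch of the three modes
(my illustration, not from the PR; it assumes ten one-dimensional samples and
the batch counts documented in the docstring example above):

```python
import mxnet as mx
import numpy as np

data = np.arange(10)  # 10 samples, batch_size 3 -> one leftover sample

# 'discard':   batches [0,1,2], [3,4,5], [6,7,8]; sample 9 is dropped.
# 'pad':       a 4th batch [9, 0, 1] is emitted in the SAME epoch, with
#              batch.pad == 2 marking the two wrapped-around samples.
# 'roll_over': sample 9 is not emitted this epoch; after reset() it is
#              prepended to the next epoch, whose first batch becomes [9, 0, 1].
for mode in ('discard', 'pad', 'roll_over'):
    it = mx.io.NDArrayIter(data, batch_size=3, last_batch_handle=mode)
    n = sum(1 for _ in it)
    print(mode, 'batches per epoch:', n)  # discard: 3, pad: 4, roll_over: 3
```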
@@ -645,36 +607,28 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
                 label_name='softmax_label'):
        super(NDArrayIter, self).__init__(batch_size)

        self.data = _init_data(data, allow_empty=False, default_name=data_name)
        self.label = _init_data(label, allow_empty=True, default_name=label_name)
        self.data = init_data(data, allow_empty=False, default_name=data_name)
        self.label = init_data(label, allow_empty=True, default_name=label_name)

        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
        if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and
                (last_batch_handle != 'discard')):
            raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
                                      " with `last_batch_handle` set to `discard`.")

        # shuffle data
        if shuffle:
            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
            self.data = _shuffle(self.data, self.idx)
            self.label = _shuffle(self.label, self.idx)
        else:
            self.idx = np.arange(self.data[0][1].shape[0])

        # batching
        if last_batch_handle == 'discard':
            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
            self.idx = self.idx[:new_n]
        self.idx = np.arange(self.data[0][1].shape[0])
        self.shuffle = shuffle
        self.last_batch_handle = last_batch_handle
        self.batch_size = batch_size
        self.cursor = -self.batch_size
        self.num_data = self.idx.shape[0]
        # shuffle
        self.reset()

        self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
        self.num_source = len(self.data_list)
        self.num_data = self.idx.shape[0]
        assert self.num_data >= batch_size, \
            "batch_size needs to be smaller than data size."
        self.cursor = -batch_size
        self.batch_size = batch_size
        self.last_batch_handle = last_batch_handle
        # used for 'roll_over'
        self._cache_data = None
        self._cache_label = None

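To illustrate the CSRNDArray restriction enforced in `__init__` above, a small
hedged sketch (assumes scipy is available; variable names are illustrative,
not from the PR):

```python
import mxnet as mx
import scipy.sparse as sp

# Sparse inputs are only supported with last_batch_handle='discard';
# any other mode raises NotImplementedError per the check above.
csr = mx.nd.sparse.array(sp.random(10, 4, density=0.5, format='csr'))
it = mx.io.NDArrayIter(csr, batch_size=3, last_batch_handle='discard')  # ok
try:
    mx.io.NDArrayIter(csr, batch_size=3, last_batch_handle='pad')
except NotImplementedError as e:
    print(e)
```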
    @property
    def provide_data(self):

@@ -694,74 +648,126 @@ def provide_label(self):

    def hard_reset(self):
        """Ignore roll over data and set to start."""
        if self.shuffle:
            self._shuffle_data()
        self.cursor = -self.batch_size
        self._cache_data = None
        self._cache_label = None

    def reset(self):
        if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
            self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
        """Resets the iterator to the beginning of the data."""
        if self.shuffle:
            self._shuffle_data()
        # the range below indicates the last batch
        if self.last_batch_handle == 'roll_over' and \
            self.num_data - self.batch_size < self.cursor < self.num_data:
            # (self.cursor - self.num_data) is the (negative) count of samples
            # left over from the last batch
            self.cursor = self.cursor - self.num_data - self.batch_size
        else:
            self.cursor = -self.batch_size

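As a sanity check on the 'roll_over' cursor arithmetic in reset(), a hedged
numeric trace (numbers assumed to match the 10-sample, batch-size-3 example;
not part of the PR):

```python
num_data, batch_size = 10, 3
cursor = 9          # epoch ended with one leftover sample (index 9) cached
# reset(): num_data - batch_size < cursor < num_data holds (7 < 9 < 10),
# so the leftover is rolled over into the next epoch:
cursor = cursor - num_data - batch_size   # 9 - 10 - 3 = -4
# iter_next() then advances the cursor to -1; since -batch_size < -1 < 0,
# _batchify() concatenates the cached sample with the first
# batch_size + cursor = 2 samples of the new epoch -> a full batch of 3.
print(cursor)  # -4
```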
    def iter_next(self):
        """Increments the cursor by batch_size for the next batch
        and checks whether the cursor has exceeded the number of data points."""
        self.cursor += self.batch_size
        return self.cursor < self.num_data

    def next(self):
        if self.iter_next():
            return DataBatch(data=self.getdata(), label=self.getlabel(), \
                             pad=self.getpad(), index=None)
        else:
        """Returns the next batch of data."""
        if not self.iter_next():
            raise StopIteration
        data = self.getdata()
        label = self.getlabel()
        # iter should stop when last batch is not complete
        if data[0].shape[0] != self.batch_size:
            # in this case, cache it for next epoch
            self._cache_data = data
            self._cache_label = label
            raise StopIteration
        return DataBatch(data=data, label=label, \
                         pad=self.getpad(), index=None)

    def _getdata(self, data_source, start=None, end=None):
        """Load data from underlying arrays."""
        assert start is not None or end is not None, 'should at least specify start or end'
        start = start if start is not None else 0
        end = end if end is not None else data_source[0][1].shape[0]
        s = slice(start, end)
        return [
            x[1][s]
            if isinstance(x[1], (np.ndarray, NDArray)) else
            # h5py (only supports indices in increasing order)
            array(x[1][sorted(self.idx[s])][[
                list(self.idx[s]).index(i)
                for i in sorted(self.idx[s])
            ]]) for x in data_source
        ]

    def _getdata(self, data_source):
    def _concat(self, first_data, second_data):
        """Helper function to concat two NDArrays."""
        return [
            concat(first_data[0], second_data[0], dim=0)
        ]

    def _batchify(self, data_source):
        """Load data from underlying arrays, internal use only."""
        assert(self.cursor < self.num_data), "DataIter needs reset."
        if self.cursor + self.batch_size <= self.num_data:
            return [
                # np.ndarray or NDArray case
                x[1][self.cursor:self.cursor + self.batch_size]
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                array(x[1][sorted(self.idx[
                    self.cursor:self.cursor + self.batch_size])][[
                        list(self.idx[self.cursor:
                                      self.cursor + self.batch_size]).index(i)
                        for i in sorted(self.idx[
                            self.cursor:self.cursor + self.batch_size])
                    ]]) for x in data_source
            ]
        else:
        assert self.cursor < self.num_data, 'DataIter needs reset.'
        # first batch of next epoch with 'roll_over'
        if self.last_batch_handle == 'roll_over' and \
            -self.batch_size < self.cursor < 0:
            assert self._cache_data is not None or self._cache_label is not None, \
                'next epoch should have cached data'
            cache_data = self._cache_data if self._cache_data is not None else self._cache_label
            second_data = self._getdata(
                data_source, end=self.cursor + self.batch_size)
            if self._cache_data is not None:
                self._cache_data = None
            else:
                self._cache_label = None
            return self._concat(cache_data, second_data)
        # last batch with 'pad'
        elif self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
            pad = self.batch_size - self.num_data + self.cursor
            return [
                # np.ndarray or NDArray case
                concatenate([x[1][self.cursor:], x[1][:pad]])
                if isinstance(x[1], (np.ndarray, NDArray)) else
                # h5py (only supports indices in increasing order)
                concatenate([
                    array(x[1][sorted(self.idx[self.cursor:])][[
                        list(self.idx[self.cursor:]).index(i)
                        for i in sorted(self.idx[self.cursor:])
                    ]]),
                    array(x[1][sorted(self.idx[:pad])][[
                        list(self.idx[:pad]).index(i)
                        for i in sorted(self.idx[:pad])
                    ]])
                ]) for x in data_source
            ]
            first_data = self._getdata(data_source, start=self.cursor)
            second_data = self._getdata(data_source, end=pad)
            return self._concat(first_data, second_data)
        # normal case
        else:
            if self.cursor + self.batch_size < self.num_data:
                end_idx = self.cursor + self.batch_size
            # get incomplete last batch
            else:
                end_idx = self.num_data
            return self._getdata(data_source, self.cursor, end_idx)

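A quick hedged check of the 'pad' branch arithmetic above (numbers assumed,
matching the same 10-sample, batch-size-3 example; not part of the PR):

```python
num_data, batch_size, cursor = 10, 3, 9
# the last batch starts at index 9, so only one real sample remains
pad = batch_size - num_data + cursor   # 3 - 10 + 9 = 2
# the emitted batch is rows [9] followed by rows [0, 1]: one real sample
# plus `pad` samples wrapped around from the beginning of the data.
print(pad)  # 2
```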
    def getdata(self):
        return self._getdata(self.data)
        """Get data."""
        return self._batchify(self.data)

    def getlabel(self):
        return self._getdata(self.label)
        """Get label."""
        return self._batchify(self.label)

    def getpad(self):
        """Get pad value of DataBatch."""
        if self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        # check the first batch
        elif self.last_batch_handle == 'roll_over' and \
            -self.batch_size < self.cursor < 0:
            return -self.cursor
        else:
            return 0
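One place the pad value above matters is prediction: padded rows are
wrapped-around duplicates and should be dropped. A hedged usage sketch
(`it` and `out` are illustrative names, not from the PR):

```python
# Assumes `it` is an NDArrayIter with last_batch_handle='pad'.
outputs = []
for batch in it:
    out = batch.data[0]          # stand-in for a model's forward output
    if batch.pad > 0:
        out = out[:out.shape[0] - batch.pad]  # drop the padded rows
    outputs.append(out)
```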

    def _shuffle_data(self):
        """Shuffle the data."""
        # shuffle index
        np.random.shuffle(self.idx)
        # get the data by corresponding index
        self.data = getdata_by_idx(self.data, self.idx)
        self.label = getdata_by_idx(self.label, self.idx)

class MXDataIter(DataIter):
    """A python wrapper of a C++ data iterator.

@@ -773,7 +779,7 @@ class MXDataIter(DataIter):
    underlying C++ data iterators.

    Usually you don't need to interact with `MXDataIter` directly unless you are
    implementing your own data iterators in C++. To do that, please refer to
    examples under the `src/io` folder.

    Parameters

[Review thread]
Reviewer: I'm not convinced we should create a separate utils file, and this
import will actually pollute the namespace. The original _private_function
approach is good IMO.
Author: The reason I created another file is that io.py exceeds 1000 lines.
Reviewer: Okay, makes sense to me. Here's my suggestion to make things clear:
make a new directory named `io` under python/mxnet. In its `__init__.py`,
reimport everything from io.py, so it won't break the original import path.
(A sketch of this layout follows below.)
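A minimal hedged sketch of the suggested layout. The file placement follows
the reviewer's comment; the wildcard re-export and the `utils` module name
inside the package are my assumptions, not part of the PR:

```python
# python/mxnet/io/__init__.py  (hypothetical shim)
# Re-export the public API so `from mxnet.io import NDArrayIter`
# keeps working after io.py moves into the package.
from .io import *       # the former io.py, now python/mxnet/io/io.py
from .utils import init_data, has_instance, getdata_by_idx  # assumed name
```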