Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Change the way NDArrayIter handle the last batch #12285

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,4 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
* [Cheng-Che Lee](https://github.com/stu1130)
262 changes: 134 additions & 128 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,30 +17,26 @@

"""Data iterators for common data formats."""
from __future__ import absolute_import
from collections import OrderedDict, namedtuple
from collections import namedtuple

import sys
import ctypes
import logging
import threading
try:
import h5py
except ImportError:
h5py = None
import numpy as np

from .base import _LIB
from .base import c_str_array, mx_uint, py_str
from .base import DataIterHandle, NDArrayHandle
from .base import mx_real_t
from .base import check_call, build_param_doc as _build_param_doc
from .ndarray import NDArray
from .ndarray.sparse import CSRNDArray
from .ndarray.sparse import array as sparse_array
from .ndarray import _ndarray_cls
from .ndarray import array
from .ndarray import concatenate
from .ndarray import arange
from .ndarray.random import shuffle as random_shuffle
from .ndarray import concat

from .io_utils import init_data, has_instance, getdata_by_idx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not convinced to create a separate utils file, and this import will actually pollute the namespace. The original _private_function is good IMO.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I created another file is that the io.py file exceeds 1000 lines.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay, makes sense to me. Here's my suggestion in order to make things clear:

make a new directory named io under python/mxnet

.
├── __init__.py
├── io.py
└── io_utils.py

In __init__.py, reimport everything in io.py, so it won't break the original path.


class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
Expand Down Expand Up @@ -489,59 +485,6 @@ def getindex(self):
def getpad(self):
return self.current_batch.pad

def _init_data(data, allow_empty, default_name):
    """Convert data into the canonical form: a sorted list of (name, array) pairs.

    Parameters
    ----------
    data : NDArray, numpy.ndarray, h5py.Dataset, list of them, or dict of str to them
        The input data. May be ``None`` only when ``allow_empty`` is True.
    allow_empty : bool
        Whether a ``None`` input is acceptable (yields an empty list).
    default_name : str
        Name used when the input is a single array ("default_name") or a
        plain list ("_<i>_<default_name>" per element).

    Returns
    -------
    list of (str, NDArray or h5py.Dataset)
        Name/array pairs, sorted by name.

    Raises
    ------
    TypeError
        If an element is not (convertible to) NDArray, numpy.ndarray or
        h5py.Dataset.
    """
    assert (data is not None) or allow_empty
    if data is None:
        data = []

    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
                  if h5py else (np.ndarray, NDArray)):
        data = [data]
    if isinstance(data, list):
        if not allow_empty:
            assert len(data) > 0
        if len(data) == 1:
            data = OrderedDict([(default_name, data[0])])  # pylint: disable=redefined-variable-type
        else:
            data = OrderedDict(  # pylint: disable=redefined-variable-type
                [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
    if not isinstance(data, dict):
        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " + \
                        "a list of them or dict with them as values")
    for k, v in data.items():
        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
            try:
                data[k] = array(v)
            # Fixed: was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt. Only conversion failures should become TypeError.
            except Exception:
                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) + \
                                "should be NDArray, numpy.ndarray or h5py.Dataset")

    return list(sorted(data.items()))

def _has_instance(data, dtype):
"""Return True if ``data`` has instance of ``dtype``.
This function is called after _init_data.
``data`` is a list of (str, NDArray)"""
for item in data:
_, arr = item
if isinstance(arr, dtype):
return True
return False

def _shuffle(data, idx):
    """Reorder every array in ``data`` by the index permutation ``idx``.

    h5py datasets are passed through unchanged (they are not reindexed),
    CSR arrays are reordered via their scipy representation, and dense
    arrays via numpy. Returns a new list of (name, array) pairs.
    """
    result = []
    for name, arr in data:
        if h5py and isinstance(arr, h5py.Dataset):
            reordered = arr
        elif isinstance(arr, CSRNDArray):
            reordered = sparse_array(arr.asscipy()[idx], arr.context)
        else:
            reordered = array(arr.asnumpy()[idx], arr.context)
        result.append((name, reordered))
    return result

class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, ``h5py.Dataset``
Expand Down Expand Up @@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
3
>>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
>>> batchidx = 0
>>> for batch in dataiter:
... batchidx += 1
...
>>> batchidx # Remaining examples are rolled over to the next iteration.
3
>>> dataiter.reset()
>>> dataiter.next().data[0].asnumpy()
[[[ 36. 37.]
[ 38. 39.]]
[[ 0. 1.]
[ 2. 3.]]
[[ 4. 5.]
[ 6. 7.]]]
(3L, 2L, 2L)

`NDArrayIter` also supports multiple input and labels.

Expand Down Expand Up @@ -633,8 +592,11 @@ class NDArrayIter(DataIter):
Only supported if no h5py.Dataset inputs are used.
last_batch_handle : str, optional
How to handle the last batch. This parameter can be 'pad', 'discard' or
'roll_over'. 'roll_over' is intended for training and can cause problems
if used for prediction.
'roll_over'.
If 'pad', the last batch will be padded with data starting from the beginning
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how is pad and roll_over different, it is not clear in the documentation? In both it would seem you are taking data from the first batch of off the next epoch and adding it to the current last batch

Copy link
Contributor Author

@stu1130 stu1130 Aug 22, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let say data look like this [1,2,3,4,5,6,7,8,9,10] with batch_size 3
pad would be like [1,2,3],...[7,8,9],[10,1,2], while roll_over would be [1,2,3],...[7,8,9] and second iteration would be [10,1,2], [3,4,5], [6,7,8] after calling reset().
I've updated example starting from line 610

Copy link
Contributor

@chinakook chinakook Aug 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, It's so clear with an example.

If 'discard', the last batch will be discarded
If 'roll_over', the remaining elements will be rolled over to the next iteration and
note that it is intended for training and can cause problems if used for prediction.
data_name : str, optional
The data name.
label_name : str, optional
Expand All @@ -645,36 +607,28 @@ def __init__(self, data, label=None, batch_size=1, shuffle=False,
label_name='softmax_label'):
super(NDArrayIter, self).__init__(batch_size)

self.data = _init_data(data, allow_empty=False, default_name=data_name)
self.label = _init_data(label, allow_empty=True, default_name=label_name)
self.data = init_data(data, allow_empty=False, default_name=data_name)
self.label = init_data(label, allow_empty=True, default_name=label_name)

if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, CSRNDArray)) and
(last_batch_handle != 'discard')):
raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
" with `last_batch_handle` set to `discard`.")

# shuffle data
if shuffle:
tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
self.data = _shuffle(self.data, self.idx)
self.label = _shuffle(self.label, self.idx)
else:
self.idx = np.arange(self.data[0][1].shape[0])

# batching
if last_batch_handle == 'discard':
new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
self.idx = self.idx[:new_n]
self.idx = np.arange(self.data[0][1].shape[0])
self.shuffle = shuffle
self.last_batch_handle = last_batch_handle
self.batch_size = batch_size
self.cursor = -self.batch_size
self.num_data = self.idx.shape[0]
# shuffle
self.reset()

self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
self.num_source = len(self.data_list)
self.num_data = self.idx.shape[0]
assert self.num_data >= batch_size, \
"batch_size needs to be smaller than data size."
self.cursor = -batch_size
self.batch_size = batch_size
self.last_batch_handle = last_batch_handle
# used for 'roll_over'
self._cache_data = None
self._cache_label = None

@property
def provide_data(self):
Expand All @@ -694,74 +648,126 @@ def provide_label(self):

def hard_reset(self):
    """Discard any rolled-over remainder and restart from the very beginning."""
    if self.shuffle:
        self._shuffle_data()
    # Drop cached leftovers from a previous 'roll_over' epoch, then rewind.
    self._cache_data = None
    self._cache_label = None
    self.cursor = -self.batch_size

def reset(self):
if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
"""Resets the iterator to the beginning of the data."""
if self.shuffle:
self._shuffle_data()
# the range below indicate the last batch
if self.last_batch_handle == 'roll_over' and \
self.num_data - self.batch_size < self.cursor < self.num_data:
# (self.cursor - self.num_data) represents the data we have for the last batch
self.cursor = self.cursor - self.num_data - self.batch_size
else:
self.cursor = -self.batch_size

def iter_next(self):
    """Advance the cursor by one batch size; return True while data remains."""
    advanced = self.cursor + self.batch_size
    self.cursor = advanced
    return advanced < self.num_data

def next(self):
if self.iter_next():
return DataBatch(data=self.getdata(), label=self.getlabel(), \
pad=self.getpad(), index=None)
else:
"""Returns the next batch of data."""
if not self.iter_next():
raise StopIteration
data = self.getdata()
label = self.getlabel()
# iter should stop when last batch is not complete
if data[0].shape[0] != self.batch_size:
# in this case, cache it for next epoch
self._cache_data = data
self._cache_label = label
raise StopIteration
return DataBatch(data=data, label=label, \
pad=self.getpad(), index=None)

def _getdata(self, data_source, start=None, end=None):
    """Slice the half-open range [start, end) out of every underlying array.

    At least one of ``start``/``end`` must be given; the other defaults to
    0 / the full length respectively. ``data_source`` is the canonical
    list of (name, array) pairs.
    """
    assert start is not None or end is not None, 'should at least specify start or end'
    if start is None:
        start = 0
    if end is None:
        end = data_source[0][1].shape[0]

    batch = []
    for _, src in data_source:
        if isinstance(src, (np.ndarray, NDArray)):
            batch.append(src[start:end])
        else:
            # h5py only supports indices in increasing order, so fetch the
            # slice of self.idx sorted, then reindex the fetched rows.
            idx_slice = self.idx[start:end]
            sorted_idx = sorted(idx_slice)
            order = [list(idx_slice).index(i) for i in sorted_idx]
            batch.append(array(src[sorted_idx][order]))
    return batch

def _getdata(self, data_source):
def _concat(self, first_data, second_data):
    """Concatenate two batches entry-wise along the batch (first) dimension.

    Both arguments are lists of NDArrays, one entry per data (or label)
    source. Returns a list of the same length.

    Fixed: the previous version concatenated only ``first_data[0]`` with
    ``second_data[0]``, silently dropping every additional entry even
    though `NDArrayIter` documents support for multiple inputs/labels.
    """
    assert len(first_data) == len(second_data), \
        'data source should contain the same number of arrays'
    return [
        concat(first_part, second_part, dim=0)
        for first_part, second_part in zip(first_data, second_data)
    ]

def _batchify(self, data_source):
"""Load data from underlying arrays, internal use only."""
assert(self.cursor < self.num_data), "DataIter needs reset."
if self.cursor + self.batch_size <= self.num_data:
return [
# np.ndarray or NDArray case
x[1][self.cursor:self.cursor + self.batch_size]
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
array(x[1][sorted(self.idx[
self.cursor:self.cursor + self.batch_size])][[
list(self.idx[self.cursor:
self.cursor + self.batch_size]).index(i)
for i in sorted(self.idx[
self.cursor:self.cursor + self.batch_size])
]]) for x in data_source
]
else:
assert self.cursor < self.num_data, 'DataIter needs reset.'
# first batch of next epoch with 'roll_over'
if self.last_batch_handle == 'roll_over' and \
-self.batch_size < self.cursor < 0:
assert self._cache_data is not None or self._cache_label is not None, \
'next epoch should have cached data'
cache_data = self._cache_data if self._cache_data is not None else self._cache_label
second_data = self._getdata(
data_source, end=self.cursor + self.batch_size)
if self._cache_data is not None:
self._cache_data = None
else:
self._cache_label = None
return self._concat(cache_data, second_data)
# last batch with 'pad'
elif self.last_batch_handle == 'pad' and \
self.cursor + self.batch_size > self.num_data:
pad = self.batch_size - self.num_data + self.cursor
return [
# np.ndarray or NDArray case
concatenate([x[1][self.cursor:], x[1][:pad]])
if isinstance(x[1], (np.ndarray, NDArray)) else
# h5py (only supports indices in increasing order)
concatenate([
array(x[1][sorted(self.idx[self.cursor:])][[
list(self.idx[self.cursor:]).index(i)
for i in sorted(self.idx[self.cursor:])
]]),
array(x[1][sorted(self.idx[:pad])][[
list(self.idx[:pad]).index(i)
for i in sorted(self.idx[:pad])
]])
]) for x in data_source
]
first_data = self._getdata(data_source, start=self.cursor)
second_data = self._getdata(data_source, end=pad)
return self._concat(first_data, second_data)
# normal case
else:
if self.cursor + self.batch_size < self.num_data:
end_idx = self.cursor + self.batch_size
# get incomplete last batch
else:
end_idx = self.num_data
return self._getdata(data_source, self.cursor, end_idx)

def getdata(self):
return self._getdata(self.data)
"""Get data."""
return self._batchify(self.data)

def getlabel(self):
return self._getdata(self.label)
"""Get label."""
return self._batchify(self.label)

def getpad(self):
    """Return the number of padded examples in the current batch (0 if none)."""
    overflow = self.cursor + self.batch_size - self.num_data
    if self.last_batch_handle == 'pad' and overflow > 0:
        # Last batch under 'pad': the overflow was filled from the beginning.
        return overflow
    if self.last_batch_handle == 'roll_over' and -self.batch_size < self.cursor < 0:
        # A negative cursor in this range marks the first batch after a roll-over.
        return -self.cursor
    return 0

def _shuffle_data(self):
    """Shuffle data and label arrays using one shared random permutation."""
    # Permuting self.idx in place keeps data and label aligned, since both
    # are gathered through the same index below.
    np.random.shuffle(self.idx)
    self.label = getdata_by_idx(self.label, self.idx)
    self.data = getdata_by_idx(self.data, self.idx)

class MXDataIter(DataIter):
"""A python wrapper a C++ data iterator.
Expand All @@ -773,7 +779,7 @@ class MXDataIter(DataIter):
underlying C++ data iterators.

Usually you don't need to interact with `MXDataIter` directly unless you are
implementing your own data iterators in C++. To do that, please refer to
implementing your own data iterators in C++. To do that, please refer to
examples under the `src/io` folder.

Parameters
Expand Down
Loading