Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Index to IO Batch #504

Merged
merged 26 commits into from
Nov 12, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,14 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
*/
MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
NDArrayHandle *out);

/*!
 * \brief Get the image index array of the current batch.
 * \param handle the handle pointer to the data iterator
 * \param out_index set to the iterator-owned index array; const data, callers
 *        must treat it as read-only and not use it after the iterator advances
 * \param out_size set to the number of entries in the index array
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle,
                                 uint64_t **out_index,
                                 uint64_t *out_size);
/*!
* \brief Get the padding number in current data batch
* \param handle the handle pointer to the data iterator
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ struct DataInst {
struct DataBatch {
/*! \brief content of dense data, if this DataBatch is dense */
std::vector<NDArray> data;
/*! \brief index of image data */
std::vector<uint64_t> index;
/*! \brief extra data to be fed to the network */
std::string extra_data;
/*! \brief num of example padded to batch */
Expand Down
32 changes: 27 additions & 5 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def getlabel(self):
"""
return self.getdata(-1)

def getindex(self):
    """Get the index of the current batch.

    Base-class stub; concrete iterators override this to supply the
    per-example index array (may return None when no index is available).

    Returns
    -------
    index : numpy.array
        The index of current batch
    """
    pass

def getpad(self):
"""Get the number of padding examples in current batch.
Returns
Expand All @@ -84,7 +93,7 @@ def getpad(self):
pass


DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad'])
DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index'])

def _init_data(data, allow_empty, default_name):
"""Convert data into canonical form."""
Expand Down Expand Up @@ -198,7 +207,8 @@ def iter_next(self):

def next(self):
    """Advance the iterator and return the next batch.

    Returns
    -------
    DataBatch
        Batch with data, label, pad count, and index (always None for
        this in-memory iterator).

    Raises
    ------
    StopIteration
        When the underlying data is exhausted.
    """
    # Guard clause: stop as soon as the backing store runs out.
    if not self.iter_next():
        raise StopIteration
    # NOTE: the stale pre-diff call without index= was removed; this
    # iterator has no per-example index, so index is explicitly None.
    return DataBatch(data=self.getdata(),
                     label=self.getlabel(),
                     pad=self.getpad(),
                     index=None)

Expand Down Expand Up @@ -272,17 +282,18 @@ def reset(self):

def next(self):
    """Advance the C-backed iterator and return the next batch.

    Returns
    -------
    DataBatch
        Batch with data, label, pad count, and the per-example index
        fetched via getindex().

    Raises
    ------
    StopIteration
        When the underlying C iterator reports no more batches.
    """
    # Debug path: replay the current batch without reloading from C.
    if self._debug_skip_load and not self._debug_at_begin:
        return DataBatch(data=[self.getdata()], label=[self.getlabel()],
                         pad=self.getpad(), index=self.getindex())
    # A batch may have been pre-fetched (e.g. to infer shapes); hand it
    # out exactly once before touching the C iterator again.
    if self.first_batch is not None:
        batch = self.first_batch
        self.first_batch = None
        return batch

    self._debug_at_begin = False
    next_res = ctypes.c_int(0)
    check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
    if not next_res.value:
        raise StopIteration
    return DataBatch(data=[self.getdata()], label=[self.getlabel()],
                     pad=self.getpad(), index=self.getindex())

Expand All @@ -303,6 +314,17 @@ def getlabel(self):
check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
return NDArray(hdl, False)

def getindex(self):
    """Get the per-example index array for the current batch.

    Calls MXDataIterGetIndex, which returns a pointer into the
    iterator's internal buffer, and copies it into an owned array.

    Returns
    -------
    index : numpy.array
        Copy of the uint64 index array for the current batch; empty
        when the iterator supplies no index.
    """
    index_size = ctypes.c_uint64(0)
    index_data = ctypes.POINTER(ctypes.c_uint64)()
    check_call(_LIB.MXDataIterGetIndex(self.handle,
                                       ctypes.byref(index_data),
                                       ctypes.byref(index_size)))
    # Guard: a NULL pointer or zero size would make .contents raise;
    # return an owned empty array instead of crashing.
    if index_size.value == 0 or not index_data:
        return np.empty(0, dtype=np.uint64)
    address = ctypes.addressof(index_data.contents)
    dbuffer = (ctypes.c_uint64 * index_size.value).from_address(address)
    # Copy so the result outlives the iterator's internal buffer,
    # which is invalidated on the next call to Next().
    np_index = np.frombuffer(dbuffer, dtype=np.uint64)
    return np_index.copy()

def getpad(self):
pad = ctypes.c_int(0)
check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
Expand Down
8 changes: 8 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_END();
}

// Expose the current batch's index array through the C API.
// The buffer is owned by the DataBatch inside the iterator, so it is only
// valid until the iterator advances; callers must copy it if they need it
// longer (the Python binding does np.frombuffer(...).copy()).
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) {
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  *out_size = db.index.size();
  // const_cast only to satisfy the non-const C signature; per the c_api.h
  // comment the data is const and callers must treat it as read-only.
  *out_index = const_cast<uint64_t*>(db.index.data());
  API_END();
}

int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
Expand Down
2 changes: 2 additions & 0 deletions src/io/iter_batchloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end());
// Init space for out_
out_.inst_index = new unsigned[param_.batch_size];
out_.batch_size = param_.batch_size;
out_.data.clear();
data_holder_ = mshadow::NewTensor<mshadow::cpu>(data_shape_.get<4>(), 0.0f);
label_holder_ = mshadow::NewTensor<mshadow::cpu>(label_shape_.get<2>(), 0.0f);
Expand All @@ -88,6 +89,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
}
inline bool Next(void) {
out_.num_batch_padd = 0;
out_.batch_size = param_.batch_size;
this->head_ = 0;

// if overflow from previous round, directly return false, until before first is called
Expand Down
10 changes: 8 additions & 2 deletions src/io/iter_prefetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include "./inst_vector.h"

namespace mxnet {
Expand Down Expand Up @@ -62,12 +63,12 @@ class PrefetcherIter : public IIterator<DataBatch> {
iter_.Init([this](DataBatch **dptr) {
if (!loader_->Next()) return false;
const TBlobBatch& batch = loader_->Value();

if (*dptr == nullptr) {
// allocate databatch
*dptr = new DataBatch();
(*dptr)->num_batch_padd = batch.num_batch_padd;
(*dptr)->data.resize(batch.data.size());
(*dptr)->index.resize(batch.batch_size);
for (size_t i = 0; i < batch.data.size(); ++i) {
(*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU());
}
Expand All @@ -80,7 +81,12 @@ class PrefetcherIter : public IIterator<DataBatch> {
batch.data[i].FlatTo2D<cpu, real_t>());
(*dptr)->num_batch_padd = batch.num_batch_padd;
}
return true;
if (batch.inst_index) {
std::copy(batch.inst_index,
batch.inst_index + batch.batch_size,
(*dptr)->index.begin());
}
return true;
},
[this]() { loader_->BeforeFirst(); });
}
Expand Down