This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add Index to IO Batch #504

Merged: 26 commits, Nov 12, 2015
13 changes: 13 additions & 0 deletions include/mxnet/c_api.h
@@ -708,6 +708,19 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
*/
MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
NDArrayHandle *out);
/*!
 * \brief Get the image indices of the current batch
 * \param handle the handle pointer to the data iterator
 * \param out_index caller-allocated array (length = batch size) to receive the indices
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle,
                                 uint64_t *out_index);
/*!
 * \brief Get the batch size of the current batch
 */
MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle,
                                     mx_uint* batch_size);

/*!
* \brief Get the padding number in current data batch
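As a caller-side illustration of how these two new C API entries are meant to be used together, here is a minimal hypothetical sketch (GetBatchIndices is an invented helper, not part of this diff): query the batch size first, allocate the index buffer, then let the iterator fill it.

// Hypothetical caller helper, not part of this PR.
#include <cstdint>
#include <vector>
#include <mxnet/c_api.h>

std::vector<uint64_t> GetBatchIndices(DataIterHandle iter) {
  mx_uint batch_size = 0;
  // Ask the iterator how many examples the current batch holds.
  if (MXDataIterGetBatchsize(iter, &batch_size) != 0) return {};
  // The caller allocates the buffer; MXDataIterGetIndex fills it.
  std::vector<uint64_t> index(batch_size);
  if (MXDataIterGetIndex(iter, index.data()) != 0) return {};
  return index;
}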
2 changes: 2 additions & 0 deletions include/mxnet/io.h
@@ -60,6 +60,8 @@ struct DataInst {
struct DataBatch {
/*! \brief content of dense data, if this DataBatch is dense */
std::vector<NDArray> data;
/*! \brief image indices of the examples in this batch */
std::vector<uint64_t> index;
/*! \brief extra data to be fed to the network */
std::string extra_data;
/*! \brief num of example padded to batch */
33 changes: 31 additions & 2 deletions python/mxnet/io.py
@@ -79,6 +79,24 @@ def getlabel(self):
"""
return self.getdata(-1)

def getindex(self):
"""
Returns
-------
index : numpy.array
The indices of the current batch
"""
pass

def getbatchsize(self):
"""
Returns
-------
batch_size : int
The size of the current batch
"""
pass

def getpad(self):
"""Get the number of padding examples in current batch.

@@ -239,12 +257,12 @@ def reset(self):

def next(self):
if self._debug_skip_load and not self._debug_at_begin:
return self.getdata(), self.getlabel()
return self.getdata(), self.getlabel(), self.getindex()
self._debug_at_begin = False
next_res = ctypes.c_int(0)
check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
if next_res.value:
return self.getdata(), self.getlabel()
return self.getdata(), self.getlabel(), self.getindex()
else:
raise StopIteration

@@ -263,6 +281,17 @@ def getlabel(self):
check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
return NDArray(hdl, False)

def getindex(self):
batch_size = self.getbatchsize()
index = np.zeros((batch_size), dtype=np.uint64)
check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data))
return index

def getbatchsize(self):
batch_size = ctypes.c_uint32(0)
check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size)))
return batch_size.value

def getpad(self):
pad = ctypes.c_int(0)
check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
2 changes: 1 addition & 1 deletion python/mxnet/model.py
@@ -301,7 +301,7 @@ def _train_multi_device(symbol, ctx, input_shape,
nbatch = 0
# Iterate over training data.
while True:
for data, label in train_data:
for data, label, index in train_data:
# Copy data into the target
for target, islice in zip(arg_blocks[label_index], slices):
label[islice].copyto(target)
16 changes: 16 additions & 0 deletions src/c_api/c_api.cc
@@ -838,6 +838,22 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_END();
}

int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) {
Review comment (Member):
Try to return in one function,

int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, size_t *out_size)

API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
for (size_t i = 0; i < db.index.size(); ++i) {
out_index[i] = db.index[i];
}
API_END();
}
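Following up on the review comment above, a minimal sketch of the suggested single-call form (hypothetical; not what this diff implements) could hand back the batch's own index storage together with its length:

// Sketch of the reviewer's suggested signature, not part of this PR.
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index,
                       size_t *out_size) {
  API_BEGIN();
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  // Expose a pointer into the batch's index vector and report its length,
  // so the caller needs neither a separate batch-size query nor a copy.
  *out_index = const_cast<uint64_t*>(db.index.data());
  *out_size = db.index.size();
  API_END();
}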

MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
*batch_size = (mx_uint)db.index.size();
API_END();
}

int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
2 changes: 2 additions & 0 deletions src/io/iter_batchloader.h
@@ -69,6 +69,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end());
// Init space for out_
out_.inst_index = new unsigned[param_.batch_size];
out_.batch_size = param_.batch_size;
out_.data.clear();
data_holder_ = mshadow::NewTensor<mshadow::cpu>(data_shape_.get<4>(), 0.0f);
label_holder_ = mshadow::NewTensor<mshadow::cpu>(label_shape_.get<2>(), 0.0f);
@@ -88,6 +89,7 @@
}
inline bool Next(void) {
out_.num_batch_padd = 0;
out_.batch_size = param_.batch_size;
this->head_ = 0;

// if overflow from previous round, directly return false, until before first is called
5 changes: 4 additions & 1 deletion src/io/iter_prefetcher.h
@@ -62,12 +62,12 @@ class PrefetcherIter : public IIterator<DataBatch> {
iter_.Init([this](DataBatch **dptr) {
if (!loader_->Next()) return false;
const TBlobBatch& batch = loader_->Value();

if (*dptr == nullptr) {
// allocate databatch
*dptr = new DataBatch();
(*dptr)->num_batch_padd = batch.num_batch_padd;
(*dptr)->data.resize(batch.data.size());
(*dptr)->index.resize(batch.batch_size);
for (size_t i = 0; i < batch.data.size(); ++i) {
(*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU());
}
@@ -80,6 +80,9 @@
batch.data[i].FlatTo2D<cpu, real_t>());
(*dptr)->num_batch_padd = batch.num_batch_padd;
}
for (size_t i = 0; i < batch.batch_size; ++i) {
(*dptr)->index[i] = batch.inst_index[i];
Review comment (Member):
use std::copy
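For example, a sketch of that replacement (assuming <algorithm> is included and batch.inst_index holds batch.batch_size valid entries):

// Copy the per-example indices in one call instead of an explicit loop;
// the unsigned -> uint64_t widening happens element-wise during the copy.
std::copy(batch.inst_index,
          batch.inst_index + batch.batch_size,
          (*dptr)->index.begin());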

}
return true;
},
[this]() { loader_->BeforeFirst(); });