Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #504 from junranhe/master
Browse files Browse the repository at this point in the history
Add Index to IO Batch
  • Loading branch information
tqchen committed Nov 12, 2015
2 parents eb1e347 + 6492418 commit d3bd2d5
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 8 deletions.
9 changes: 8 additions & 1 deletion include/mxnet/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,14 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle);
*/
MXNET_DLL int MXDataIterGetData(DataIterHandle handle,
NDArrayHandle *out);

/*!
 * \brief Get the image indices of the current data batch.
 * \param handle the handle pointer to the data iterator
 * \param out_index receives a pointer to the index array; the data is const and owned by the iterator
 * \param out_size receives the number of elements in the index array
 * \return 0 when success, -1 when failure happens
 */
MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle,
uint64_t **out_index,
uint64_t *out_size);
/*!
* \brief Get the padding number in current data batch
* \param handle the handle pointer to the data iterator
Expand Down
2 changes: 2 additions & 0 deletions include/mxnet/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ struct DataInst {
struct DataBatch {
/*! \brief content of dense data, if this DataBatch is dense */
std::vector<NDArray> data;
/*! \brief index of image data */
std::vector<uint64_t> index;
/*! \brief extra data to be fed to the network */
std::string extra_data;
/*! \brief num of example padded to batch */
Expand Down
32 changes: 27 additions & 5 deletions python/mxnet/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,15 @@ def getlabel(self):
"""
return self.getdata(-1)

def getindex(self):
    """Get the index of the current batch.

    Returns
    -------
    index : numpy.array
        The index of current batch; base class returns None
        (subclasses that track per-image indices override this).
    """
    # Default implementation: no index information available.
    pass

def getpad(self):
"""Get the number of padding examples in current batch.
Returns
Expand All @@ -84,7 +93,7 @@ def getpad(self):
pass


DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad'])
DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index'])

def _init_data(data, allow_empty, default_name):
"""Convert data into canonical form."""
Expand Down Expand Up @@ -198,7 +207,8 @@ def iter_next(self):

def next(self):
    """Advance the iterator and return the next DataBatch.

    Raises
    ------
    StopIteration
        When the underlying data is exhausted.
    """
    # Guard clause: stop as soon as the iterator is exhausted.
    if not self.iter_next():
        raise StopIteration
    # In-memory iterators carry no per-image index, hence index=None.
    return DataBatch(data=self.getdata(), label=self.getlabel(),
                     pad=self.getpad(), index=None)

Expand Down Expand Up @@ -272,17 +282,18 @@ def reset(self):

def next(self):
    """Advance the C-backed iterator and return the next DataBatch.

    Returns
    -------
    DataBatch
        Batch with single-element data/label lists, pad count and index array.

    Raises
    ------
    StopIteration
        When the underlying iterator has no more batches.
    """
    # Debug shortcut: once past the first batch, keep returning the
    # currently-loaded batch without advancing the C iterator.
    if self._debug_skip_load and not self._debug_at_begin:
        return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                         index=self.getindex())
    # A batch may have been pre-fetched (e.g. to probe shapes); hand it
    # out exactly once before touching the C iterator again.
    if self.first_batch is not None:
        batch = self.first_batch
        self.first_batch = None
        return batch

    self._debug_at_begin = False
    next_res = ctypes.c_int(0)
    # Ask the backend whether another batch is available; next_res is an
    # int out-parameter (nonzero == success).
    check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res)))
    if next_res.value:
        return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(),
                         index=self.getindex())
    else:
        raise StopIteration

Expand All @@ -303,6 +314,17 @@ def getlabel(self):
check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl)))
return NDArray(hdl, False)

def getindex(self):
    """Get the image indices of the current batch from the C backend.

    Returns
    -------
    index : numpy.array or None
        Copy of the uint64 index array for the current batch, or None
        when the iterator provides no index information.
    """
    index_size = ctypes.c_uint64(0)
    index_data = ctypes.POINTER(ctypes.c_uint64)()
    check_call(_LIB.MXDataIterGetIndex(self.handle,
                                       ctypes.byref(index_data),
                                       ctypes.byref(index_size)))
    # Guard the empty case: with size 0 the backend may return a NULL
    # pointer, and ctypes raises ValueError when dereferencing
    # NULL .contents via addressof().
    if index_size.value == 0:
        return None
    address = ctypes.addressof(index_data.contents)
    dbuffer = (ctypes.c_uint64 * index_size.value).from_address(address)
    np_index = np.frombuffer(dbuffer, dtype=np.uint64)
    # Copy so the result stays valid after the iterator advances and
    # the backend reuses/frees the underlying buffer.
    return np_index.copy()

def getpad(self):
pad = ctypes.c_int(0)
check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad)))
Expand Down
8 changes: 8 additions & 0 deletions src/c_api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) {
API_END();
}

// Expose the image index array of the iterator's current DataBatch
// through the C API. out_size receives the element count; out_index
// receives a pointer into the batch's own storage.
int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) {
  API_BEGIN();
  // handle is an opaque IIterator<DataBatch>*; Value() yields the current batch.
  const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
  *out_size = db.index.size();
  // const_cast because the C signature has no const out-pointer; callers must
  // treat the array as read-only — it is owned by the iterator and only valid
  // until the iterator advances.
  *out_index = const_cast<uint64_t*>(db.index.data());
  API_END();
}

int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) {
API_BEGIN();
const DataBatch& db = static_cast<IIterator<DataBatch>* >(handle)->Value();
Expand Down
2 changes: 2 additions & 0 deletions src/io/iter_batchloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end());
// Init space for out_
out_.inst_index = new unsigned[param_.batch_size];
out_.batch_size = param_.batch_size;
out_.data.clear();
data_holder_ = mshadow::NewTensor<mshadow::cpu>(data_shape_.get<4>(), 0.0f);
label_holder_ = mshadow::NewTensor<mshadow::cpu>(label_shape_.get<2>(), 0.0f);
Expand All @@ -88,6 +89,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
}
inline bool Next(void) {
out_.num_batch_padd = 0;
out_.batch_size = param_.batch_size;
this->head_ = 0;

// if overflow from previous round, directly return false, until before first is called
Expand Down
10 changes: 8 additions & 2 deletions src/io/iter_prefetcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <string>
#include <vector>
#include <queue>
#include <algorithm>
#include "./inst_vector.h"

namespace mxnet {
Expand Down Expand Up @@ -62,12 +63,12 @@ class PrefetcherIter : public IIterator<DataBatch> {
iter_.Init([this](DataBatch **dptr) {
if (!loader_->Next()) return false;
const TBlobBatch& batch = loader_->Value();

if (*dptr == nullptr) {
// allocate databatch
*dptr = new DataBatch();
(*dptr)->num_batch_padd = batch.num_batch_padd;
(*dptr)->data.resize(batch.data.size());
(*dptr)->index.resize(batch.batch_size);
for (size_t i = 0; i < batch.data.size(); ++i) {
(*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU());
}
Expand All @@ -80,7 +81,12 @@ class PrefetcherIter : public IIterator<DataBatch> {
batch.data[i].FlatTo2D<cpu, real_t>());
(*dptr)->num_batch_padd = batch.num_batch_padd;
}
return true;
if (batch.inst_index) {
std::copy(batch.inst_index,
batch.inst_index + batch.batch_size,
(*dptr)->index.begin());
}
return true;
},
[this]() { loader_->BeforeFirst(); });
}
Expand Down

0 comments on commit d3bd2d5

Please sign in to comment.