From f71f1b1200b33b1e89509cff6c2c8d10da2553a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Tue, 3 Nov 2015 17:32:45 +0800 Subject: [PATCH 01/23] add enable image index --- include/mxnet/c_api.h | 10 +++++++++- include/mxnet/io.h | 2 ++ src/c_api/c_api.cc | 8 ++++++++ src/io/iter_prefetcher.h | 2 ++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 911e7e31c0f1..849fb3d3301e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -708,7 +708,15 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); */ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out); - +/*! + * \brief Get the image index by array + * \param handle the handle pointer to the data iterator + * \param index array size + * \return image index array + */ +MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, + mx_uint index_size, + uint64_t *out_index); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/include/mxnet/io.h b/include/mxnet/io.h index a53e47e32d80..b4429a951920 100644 --- a/include/mxnet/io.h +++ b/include/mxnet/io.h @@ -60,6 +60,8 @@ struct DataInst { struct DataBatch { /*! \brief content of dense data, if this DataBatch is dense */ std::vector data; + /*! \brief index of image data */ + std::vector index; /*! \brief extra data to be fed to the network */ std::string extra_data; /*! \brief num of example padded to batch */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8706ac1cc86c..085c87a0691e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,6 +837,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } +int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + for (size_t i = 0; i < db.index.size(); ++i) { + out_index[i] = db.index[i]; + } + API_END(); +} int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index b3bbdb40c07e..9f96d552c659 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,8 +68,10 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); + (*dptr)->index.resize(batch.data.size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); + (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); From 0bbff0c3568f4282c6578d2f8c92e8c4405c0833 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:22:37 +0800 Subject: [PATCH 02/23] add index --- include/mxnet/c_api.h | 9 +++++++-- ps-lite | 2 +- python/mxnet/io.py | 15 +++++++++++++-- python/mxnet/model.py | 2 +- src/c_api/c_api.cc | 10 +++++++++- src/io/iter_batchloader.h | 4 +++- src/io/iter_prefetcher.h | 12 +++++++++--- 7 files changed, 43 insertions(+), 11 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 849fb3d3301e..6047087f7062 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -711,12 +711,17 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, /*! * \brief Get the image index by array * \param handle the handle pointer to the data iterator - * \param index array size + * \param batch size * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - mx_uint index_size, uint64_t *out_index); +/*! + * \brief Get current batch size of iter + */ +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, + mx_uint* batch_size); + /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/ps-lite b/ps-lite index d6a64bd00eb9..7121aa1bdb67 160000 --- a/ps-lite +++ b/ps-lite @@ -1 +1 @@ -Subproject commit d6a64bd00eb975867e34e519d6e8522d970fe345 +Subproject commit 7121aa1bdb673f047c7600eb4347fd2911021710 diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 29e95367c8c5..71b161ee1855 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -239,12 +239,12 @@ def reset(self): def next(self): if self._debug_skip_load and not self._debug_at_begin: - return self.getdata(), self.getlabel() + return self.getdata(), self.getlabel(), self.getindex() self._debug_at_begin = False next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: - return self.getdata(), self.getlabel() + return self.getdata(), self.getlabel(), self.getindex() else: raise StopIteration @@ -263,6 +263,17 @@ def getlabel(self): check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) return NDArray(hdl, False) + def getindex(self): + batch_size = self.getbatchsize() + index = np.zeros((batch_size), dtype=np.uint64) + check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data)) + return index + + def getbatchsize(self): + batch_size = ctypes.c_uint32(0) + check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size))) + return batch_size.value + def getpad(self): pad = ctypes.c_int(0) check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad))) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index dfb09b68c02a..a33edf86384c 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -301,7 +301,7 @@ def _train_multi_device(symbol, ctx, input_shape, nbatch = 0 # Iterate over training data. while True: - for data, label in train_data: + for data, label, index in train_data: # Copy data into the target for target, islice in zip(arg_blocks[label_index], slices): label[islice].copyto(target) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 085c87a0691e..97823d78406f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,7 +837,8 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + +int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); for (size_t i = 0; i < db.index.size(); ++i) { @@ -846,6 +847,13 @@ int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_ API_END(); } +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + *batch_size = (mx_uint)db.index.size(); + API_END(); +} + int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 1ec28d13d65a..f4924f4c0eeb 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,6 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -88,6 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -101,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9f96d552c659..03beec42bc0b 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,16 +62,15 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - + if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.data.size()); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); - (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); @@ -82,6 +81,13 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } + for (size_t i = 0; i < batch.batch_size; ++i) { + (*dptr)->index[i] = batch.inst_index[i]; +// printf("(%d,%d)", int((*dptr)->index[i]), +// int((*dptr)->data[1].data().FlatTo2D()[i][0])); + } + +//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From 97b3d099e3f45f550b07cb31d7c9dffa3bd42561 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:50:32 +0800 Subject: [PATCH 03/23] fix lint --- include/mxnet/c_api.h | 4 ++-- python/mxnet/io.py | 18 ++++++++++++++++++ src/io/iter_batchloader.h | 6 +++--- src/io/iter_prefetcher.h | 7 +------ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6047087f7062..ce5dbad8e591 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -715,12 +715,12 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - uint64_t *out_index); + uint64_t *out_index); /*! * \brief Get current batch size of iter */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); + mx_uint* batch_size); /*! * \brief Get the padding number in current data batch diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 71b161ee1855..5e273890951b 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -79,6 +79,24 @@ def getlabel(self): """ return self.getdata(-1) + def getindex(self): + """ + Retures + ------- + index : numpy.array + The index of current batch + """ + pass + + def getbatchsize(self): + """ + Retures + ------- + batch_size: int + The size of current batch + """ + pass + def getpad(self): """Get the number of padding examples in current batch. diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index f4924f4c0eeb..57db5f2d7846 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,7 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -89,7 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -103,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 03beec42bc0b..ea4c3b08cd9f 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,13 +62,12 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } @@ -83,11 +82,7 @@ class PrefetcherIter : public IIterator { } for (size_t i = 0; i < batch.batch_size; ++i) { (*dptr)->index[i] = batch.inst_index[i]; -// printf("(%d,%d)", int((*dptr)->index[i]), -// int((*dptr)->data[1].data().FlatTo2D()[i][0])); } - -//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From d8cb825a984c68304874635a1c5f132835d8c758 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Tue, 3 Nov 2015 17:32:45 +0800 Subject: [PATCH 04/23] add enable image index --- include/mxnet/c_api.h | 10 +++++++++- include/mxnet/io.h | 2 ++ src/c_api/c_api.cc | 8 ++++++++ src/io/iter_prefetcher.h | 2 ++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 911e7e31c0f1..849fb3d3301e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -708,7 +708,15 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); */ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out); - +/*! + * \brief Get the image index by array + * \param handle the handle pointer to the data iterator + * \param index array size + * \return image index array + */ +MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, + mx_uint index_size, + uint64_t *out_index); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/include/mxnet/io.h b/include/mxnet/io.h index a53e47e32d80..b4429a951920 100644 --- a/include/mxnet/io.h +++ b/include/mxnet/io.h @@ -60,6 +60,8 @@ struct DataInst { struct DataBatch { /*! \brief content of dense data, if this DataBatch is dense */ std::vector data; + /*! \brief index of image data */ + std::vector index; /*! \brief extra data to be fed to the network */ std::string extra_data; /*! \brief num of example padded to batch */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8706ac1cc86c..085c87a0691e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,6 +837,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } +int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + for (size_t i = 0; i < db.index.size(); ++i) { + out_index[i] = db.index[i]; + } + API_END(); +} int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index b3bbdb40c07e..9f96d552c659 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,8 +68,10 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); + (*dptr)->index.resize(batch.data.size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); + (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); From 278dfade5c2e2aa555371b094e219ff51e4cf465 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:22:37 +0800 Subject: [PATCH 05/23] add index --- include/mxnet/c_api.h | 9 +++++++-- python/mxnet/io.py | 12 +++++++++++- src/c_api/c_api.cc | 10 +++++++++- src/io/iter_batchloader.h | 4 +++- src/io/iter_prefetcher.h | 12 +++++++++--- 5 files changed, 39 insertions(+), 8 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 849fb3d3301e..6047087f7062 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -711,12 +711,17 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, /*! * \brief Get the image index by array * \param handle the handle pointer to the data iterator - * \param index array size + * \param batch size * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - mx_uint index_size, uint64_t *out_index); +/*! + * \brief Get current batch size of iter + */ +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, + mx_uint* batch_size); + /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/python/mxnet/io.py b/python/mxnet/io.py index cfef883e62e5..f41488ac6633 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -277,7 +277,6 @@ def next(self): batch = self.first_batch self.first_batch = None return batch - self._debug_at_begin = False next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) @@ -303,6 +302,17 @@ def getlabel(self): check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) return NDArray(hdl, False) + def getindex(self): + batch_size = self.getbatchsize() + index = np.zeros((batch_size), dtype=np.uint64) + check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data)) + return index + + def getbatchsize(self): + batch_size = ctypes.c_uint32(0) + check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size))) + return batch_size.value + def getpad(self): pad = ctypes.c_int(0) check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad))) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 085c87a0691e..97823d78406f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,7 +837,8 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + +int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); for (size_t i = 0; i < db.index.size(); ++i) { @@ -846,6 +847,13 @@ int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_ API_END(); } +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + *batch_size = (mx_uint)db.index.size(); + API_END(); +} + int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 1ec28d13d65a..f4924f4c0eeb 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,6 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -88,6 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -101,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9f96d552c659..03beec42bc0b 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,16 +62,15 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - + if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.data.size()); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); - (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); @@ -82,6 +81,13 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } + for (size_t i = 0; i < batch.batch_size; ++i) { + (*dptr)->index[i] = batch.inst_index[i]; +// printf("(%d,%d)", int((*dptr)->index[i]), +// int((*dptr)->data[1].data().FlatTo2D()[i][0])); + } + +//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From 950532318cd0a3238221c1f48d0bf507edc56d08 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:50:32 +0800 Subject: [PATCH 06/23] fix lint --- include/mxnet/c_api.h | 4 ++-- python/mxnet/io.py | 18 ++++++++++++++++++ src/io/iter_batchloader.h | 6 +++--- src/io/iter_prefetcher.h | 7 +------ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6047087f7062..ce5dbad8e591 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -715,12 +715,12 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - uint64_t *out_index); + uint64_t *out_index); /*! * \brief Get current batch size of iter */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); + mx_uint* batch_size); /*! * \brief Get the padding number in current data batch diff --git a/python/mxnet/io.py b/python/mxnet/io.py index f41488ac6633..9f3ec2b74d9a 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -74,6 +74,24 @@ def getlabel(self): """ return self.getdata(-1) + def getindex(self): + """ + Retures + ------- + index : numpy.array + The index of current batch + """ + pass + + def getbatchsize(self): + """ + Retures + ------- + batch_size: int + The size of current batch + """ + pass + def getpad(self): """Get the number of padding examples in current batch. Returns diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index f4924f4c0eeb..57db5f2d7846 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,7 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -89,7 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -103,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 03beec42bc0b..ea4c3b08cd9f 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,13 +62,12 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } @@ -83,11 +82,7 @@ class PrefetcherIter : public IIterator { } for (size_t i = 0; i < batch.batch_size; ++i) { (*dptr)->index[i] = batch.inst_index[i]; -// printf("(%d,%d)", int((*dptr)->index[i]), -// int((*dptr)->data[1].data().FlatTo2D()[i][0])); } - -//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From edd9b9c0187a283378e353f9d70cea9a99cddef4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 17:03:52 +0800 Subject: [PATCH 07/23] change c_api index interface --- include/mxnet/c_api.h | 12 +++--------- python/mxnet/io.py | 34 +++++++++++++--------------------- src/c_api/c_api.cc | 14 +++----------- src/io/iter_prefetcher.h | 6 +++--- 4 files changed, 22 insertions(+), 44 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index ce5dbad8e591..f82fb8ddcdbc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -711,17 +711,11 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, /*! * \brief Get the image index by array * \param handle the handle pointer to the data iterator - * \param batch size - * \return image index array + * \return image index array and array size, index is const data */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - uint64_t *out_index); -/*! - * \brief Get current batch size of iter - */ -MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); - + uint64_t **out_index, + uint64_t *out_size); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 9f3ec2b74d9a..9f1b5b9ef89e 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -83,15 +83,6 @@ def getindex(self): """ pass - def getbatchsize(self): - """ - Retures - ------- - batch_size: int - The size of current batch - """ - pass - def getpad(self): """Get the number of padding examples in current batch. Returns @@ -102,7 +93,7 @@ def getpad(self): pass -DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad']) +DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index']) def _init_data(data, allow_empty, default_name): """Convert data into canonical form.""" @@ -290,7 +281,7 @@ def reset(self): def next(self): if self._debug_skip_load and not self._debug_at_begin: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) if self.first_batch is not None: batch = self.first_batch self.first_batch = None @@ -299,7 +290,9 @@ def next(self): next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad()) + print 'label', self.getlabel().asnumpy() + print 'index', self.getindex() + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) else: raise StopIteration @@ -321,15 +314,14 @@ def getlabel(self): return NDArray(hdl, False) def getindex(self): - batch_size = self.getbatchsize() - index = np.zeros((batch_size), dtype=np.uint64) - check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data)) - return index - - def getbatchsize(self): - batch_size = ctypes.c_uint32(0) - check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size))) - return batch_size.value + index_size = ctypes.c_uint64(0) + index_data = ctypes.POINTER(ctypes.c_uint64)() + check_call(_LIB.MXDataIterGetIndex(self.handle, + ctypes.byref(index_data), + ctypes.byref(index_size))) + dbuffer = (ctypes.c_uint64* index_size.value).from_address(ctypes.addressof(index_data.contents)) + np_index = np.frombuffer(dbuffer, dtype=np.uint64) + return np_index.copy() def getpad(self): pad = ctypes.c_int(0) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 97823d78406f..5fbdae0f720c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -838,19 +838,11 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_END(); } -int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { +int MXDataIterGetIndex(DataIterHandle handle,uint64_t **out_index,uint64_t *out_size) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); - for (size_t i = 0; i < db.index.size(); ++i) { - out_index[i] = db.index[i]; - } - API_END(); -} - -MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) { - API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - *batch_size = (mx_uint)db.index.size(); + *out_size = db.index.size(); + *out_index = const_cast(db.index.data()); API_END(); } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index ea4c3b08cd9f..0474ed96157a 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -80,9 +80,9 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - for (size_t i = 0; i < batch.batch_size; ++i) { - (*dptr)->index[i] = batch.inst_index[i]; - } + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); return true; }, [this]() { loader_->BeforeFirst(); }); From 07a0de7c642c40cc91f67b3618bb0166150f3659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 17:48:19 +0800 Subject: [PATCH 08/23] fix lint --- python/mxnet/io.py | 11 +++++++---- src/c_api/c_api.cc | 2 +- src/io/iter_prefetcher.h | 5 +++-- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index c2afea62ca87..57c7b77c0889 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -281,7 +281,8 @@ def reset(self): def next(self): if self._debug_skip_load and not self._debug_at_begin: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), + index=self.getindex()) if self.first_batch is not None: batch = self.first_batch self.first_batch = None @@ -290,7 +291,8 @@ def next(self): next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), + index=self.getindex()) else: raise StopIteration @@ -314,10 +316,11 @@ def getlabel(self): def getindex(self): index_size = ctypes.c_uint64(0) index_data = ctypes.POINTER(ctypes.c_uint64)() - check_call(_LIB.MXDataIterGetIndex(self.handle, + check_call(_LIB.MXDataIterGetIndex(self.handle, ctypes.byref(index_data), ctypes.byref(index_size))) - dbuffer = (ctypes.c_uint64* index_size.value).from_address(ctypes.addressof(index_data.contents)) + address = ctypes.addressof(index_data.contents) + dbuffer = (ctypes.c_uint64* index_size.value).from_address(address) np_index = np.frombuffer(dbuffer, dtype=np.uint64) return np_index.copy() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5fbdae0f720c..f822f88da1b8 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -838,7 +838,7 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,uint64_t **out_index,uint64_t *out_size) { +int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); *out_size = db.index.size(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 82975eedc758..61ab98360dd6 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -16,6 +16,7 @@ #include #include #include +#include #include "./inst_vector.h" namespace mxnet { @@ -80,8 +81,8 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, (*dptr)->index.begin()); return true; }, From f8187a6a0bc23d8cd46e20eca57ec0cac966c284 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 19:56:01 +0800 Subject: [PATCH 09/23] fix coredown --- src/io/iter_prefetcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 61ab98360dd6..afecc2e8ef61 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,7 +68,7 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + (*dptr)->index.resize(batch.batch_size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } From 33b0d03812b00716f4a84e1378380c2f6a44d362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 20:34:37 +0800 Subject: [PATCH 10/23] fix bug with inst_index is null --- src/io/iter_prefetcher.h | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index afecc2e8ef61..0765827df13a 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,7 +68,7 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size()); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } @@ -81,9 +81,11 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, - (*dptr)->index.begin()); + if (batch.inst_index) { + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); + } return true; }, [this]() { loader_->BeforeFirst(); }); From 652e3b2d7f5f0d7320005e54c448fb9f6a514dab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 20:52:39 +0800 Subject: [PATCH 11/23] fix io index --- python/mxnet/io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 57c7b77c0889..1994767c1b4f 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -207,7 +207,8 @@ def iter_next(self): def next(self): if self.iter_next(): - return DataBatch(data=self.getdata(), label=self.getlabel(), pad=self.getpad()) + return DataBatch(data=self.getdata(), label=self.getlabel(), \ + pad=self.getpad(), index=None) else: raise StopIteration From a059d6aa5a1d8ac499178d50e96f15654ca40d76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Tue, 10 Nov 2015 10:17:01 +0800 Subject: [PATCH 12/23] checkout Makefile --- Makefile | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index b7e5eff5b8af..a173f057db60 100644 --- a/Makefile +++ b/Makefile @@ -59,8 +59,8 @@ ifeq ($(USE_OPENMP), 1) endif ifeq ($(USE_CUDNN), 1) - CFLAGS += -DMSHADOW_USE_CUDNN=1 -I$(USE_CUDNN_PATH) - LDFLAGS += -lcudnn -L$(USE_CUDNN_PATH) + CFLAGS += -DMSHADOW_USE_CUDNN=1 + LDFLAGS += -lcudnn endif ifeq ($(USE_THREADED_ENGINE), 1) @@ -158,6 +158,15 @@ rcppexport: roxygen: Rscript -e "require(roxygen2); roxygen2::roxygenise(\"R-package\")" +rpkg: roxygen + mkdir -p R-package/inst + mkdir -p R-package/inst/libs + cp -rf lib/libmxnet.so R-package/inst/libs + mkdir -p R-package/inst/include + cp -rf include/* R-package/inst/include + cp -rf dmlc-core/include/* R-package/inst/include/ + R CMD build R-package + clean: $(RM) -r build lib bin *~ */*~ */*/*~ */*/*/*~ From 503443f9c0720b1a0f32c00385143543ca09e960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Tue, 3 Nov 2015 17:32:45 +0800 Subject: [PATCH 13/23] add enable image index --- include/mxnet/c_api.h | 10 +++++++++- include/mxnet/io.h | 2 ++ src/c_api/c_api.cc | 8 ++++++++ src/io/iter_prefetcher.h | 2 ++ 4 files changed, 21 insertions(+), 1 deletion(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 911e7e31c0f1..849fb3d3301e 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -708,7 +708,15 @@ MXNET_DLL int MXDataIterBeforeFirst(DataIterHandle handle); */ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out); - +/*! + * \brief Get the image index by array + * \param handle the handle pointer to the data iterator + * \param index array size + * \return image index array + */ +MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, + mx_uint index_size, + uint64_t *out_index); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/include/mxnet/io.h b/include/mxnet/io.h index a53e47e32d80..b4429a951920 100644 --- a/include/mxnet/io.h +++ b/include/mxnet/io.h @@ -60,6 +60,8 @@ struct DataInst { struct DataBatch { /*! \brief content of dense data, if this DataBatch is dense */ std::vector data; + /*! \brief index of image data */ + std::vector index; /*! \brief extra data to be fed to the network */ std::string extra_data; /*! \brief num of example padded to batch */ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 8706ac1cc86c..085c87a0691e 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,6 +837,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } +int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + for (size_t i = 0; i < db.index.size(); ++i) { + out_index[i] = db.index[i]; + } + API_END(); +} int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index b3bbdb40c07e..9f96d552c659 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,8 +68,10 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); + (*dptr)->index.resize(batch.data.size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); + (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); From 775d97be282d0588a6e5fbc71e9ca45e286928d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:22:37 +0800 Subject: [PATCH 14/23] add index --- include/mxnet/c_api.h | 9 +++++++-- python/mxnet/io.py | 11 +++++++++++ src/c_api/c_api.cc | 10 +++++++++- src/io/iter_batchloader.h | 4 +++- src/io/iter_prefetcher.h | 12 +++++++++--- 5 files changed, 39 insertions(+), 7 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 849fb3d3301e..6047087f7062 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -711,12 +711,17 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, /*! * \brief Get the image index by array * \param handle the handle pointer to the data iterator - * \param index array size + * \param batch size * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - mx_uint index_size, uint64_t *out_index); +/*! + * \brief Get current batch size of iter + */ +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, + mx_uint* batch_size); + /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 2ca9f8712178..e9481d42f9df 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -303,6 +303,17 @@ def getlabel(self): check_call(_LIB.MXDataIterGetLabel(self.handle, ctypes.byref(hdl))) return NDArray(hdl, False) + def getindex(self): + batch_size = self.getbatchsize() + index = np.zeros((batch_size), dtype=np.uint64) + check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data)) + return index + + def getbatchsize(self): + batch_size = ctypes.c_uint32(0) + check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size))) + return batch_size.value + def getpad(self): pad = ctypes.c_int(0) check_call(_LIB.MXDataIterGetPadNum(self.handle, ctypes.byref(pad))) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 085c87a0691e..97823d78406f 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,7 +837,8 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + +int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); for (size_t i = 0; i < db.index.size(); ++i) { @@ -846,6 +847,13 @@ int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_ API_END(); } +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + *batch_size = (mx_uint)db.index.size(); + API_END(); +} + int MXDataIterGetData(DataIterHandle handle, NDArrayHandle *out) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 1ec28d13d65a..f4924f4c0eeb 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,6 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -88,6 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -101,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 9f96d552c659..03beec42bc0b 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,16 +62,15 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - + if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.data.size()); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); - (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); @@ -82,6 +81,13 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } + for (size_t i = 0; i < batch.batch_size; ++i) { + (*dptr)->index[i] = batch.inst_index[i]; +// printf("(%d,%d)", int((*dptr)->index[i]), +// int((*dptr)->data[1].data().FlatTo2D()[i][0])); + } + +//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From 72d21ad3d17ef39ebd4d99d7f38cab92dc89bc74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:50:32 +0800 Subject: [PATCH 15/23] fix lint --- include/mxnet/c_api.h | 4 ++-- python/mxnet/io.py | 18 ++++++++++++++++++ src/io/iter_batchloader.h | 6 +++--- src/io/iter_prefetcher.h | 7 +------ 4 files changed, 24 insertions(+), 11 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 6047087f7062..ce5dbad8e591 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -715,12 +715,12 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, * \return image index array */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - uint64_t *out_index); + uint64_t *out_index); /*! * \brief Get current batch size of iter */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); + mx_uint* batch_size); /*! * \brief Get the padding number in current data batch diff --git a/python/mxnet/io.py b/python/mxnet/io.py index e9481d42f9df..4bada82403b4 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -74,6 +74,24 @@ def getlabel(self): """ return self.getdata(-1) + def getindex(self): + """ + Retures + ------- + index : numpy.array + The index of current batch + """ + pass + + def getbatchsize(self): + """ + Retures + ------- + batch_size: int + The size of current batch + """ + pass + def getpad(self): """Get the number of padding examples in current batch. Returns diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index f4924f4c0eeb..57db5f2d7846 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -69,7 +69,7 @@ class BatchLoader : public IIterator { label_shape_ = TShape(label_shape_vec.begin(), label_shape_vec.end()); // Init space for out_ out_.inst_index = new unsigned[param_.batch_size]; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; out_.data.clear(); data_holder_ = mshadow::NewTensor(data_shape_.get<4>(), 0.0f); label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); @@ -89,7 +89,7 @@ class BatchLoader : public IIterator { } inline bool Next(void) { out_.num_batch_padd = 0; - out_.batch_size = param_.batch_size; + out_.batch_size = param_.batch_size; this->head_ = 0; // if overflow from previous round, directly return false, until before first is called @@ -103,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 03beec42bc0b..ea4c3b08cd9f 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -62,13 +62,12 @@ class PrefetcherIter : public IIterator { iter_.Init([this](DataBatch **dptr) { if (!loader_->Next()) return false; const TBlobBatch& batch = loader_->Value(); - if (*dptr == nullptr) { // allocate databatch *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } @@ -83,11 +82,7 @@ class PrefetcherIter : public IIterator { } for (size_t i = 0; i < batch.batch_size; ++i) { (*dptr)->index[i] = batch.inst_index[i]; -// printf("(%d,%d)", int((*dptr)->index[i]), -// int((*dptr)->data[1].data().FlatTo2D()[i][0])); } - -//printf("\n-------------------\n"); return true; }, [this]() { loader_->BeforeFirst(); }); From 120cf4bfd0ac6c22248478e7d63aba20668e1f93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Tue, 3 Nov 2015 17:32:45 +0800 Subject: [PATCH 16/23] add enable image index --- include/mxnet/c_api.h | 1 - src/c_api/c_api.cc | 8 ++++++++ src/io/iter_prefetcher.h | 1 + 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index ce5dbad8e591..7ef4d4463cbb 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -721,7 +721,6 @@ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size); - /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 97823d78406f..744377d80a00 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,6 +837,14 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } +int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + API_BEGIN(); + const DataBatch& db = static_cast* >(handle)->Value(); + for (size_t i = 0; i < db.index.size(); ++i) { + out_index[i] = db.index[i]; + } + API_END(); +} int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { API_BEGIN(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index ea4c3b08cd9f..702577c7281c 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -70,6 +70,7 @@ class PrefetcherIter : public IIterator { (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); + (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); From 7df73de3c7a87a4273f309892d5d7847bd272057 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:22:37 +0800 Subject: [PATCH 17/23] add index --- include/mxnet/c_api.h | 6 ++++++ python/mxnet/io.py | 1 - src/c_api/c_api.cc | 3 ++- src/io/iter_batchloader.h | 2 +- src/io/iter_prefetcher.h | 1 - 5 files changed, 9 insertions(+), 4 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 7ef4d4463cbb..62ad72d538fa 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -721,6 +721,12 @@ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size); +/*! + * \brief Get current batch size of iter + */ +MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, + mx_uint* batch_size); + /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 4bada82403b4..7922aaf210ba 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -295,7 +295,6 @@ def next(self): batch = self.first_batch self.first_batch = None return batch - self._debug_at_begin = False next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 744377d80a00..76c27952cbb5 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -837,7 +837,8 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { *out = pndarray; API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,mx_uint index_size, uint64_t *out_index) { + +int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); for (size_t i = 0; i < db.index.size(); ++i) { diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 57db5f2d7846..4ef650ccce46 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -103,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 702577c7281c..ea4c3b08cd9f 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -70,7 +70,6 @@ class PrefetcherIter : public IIterator { (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); - (*dptr)->index.at(i) = batch.inst_index[i]; } } CHECK(batch.data.size() == (*dptr)->data.size()); From 318ca2a6add8b30e7c302656639808b31759c417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Fri, 6 Nov 2015 18:50:32 +0800 Subject: [PATCH 18/23] fix lint --- include/mxnet/c_api.h | 2 +- src/io/iter_batchloader.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 62ad72d538fa..370c223b7130 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -725,7 +725,7 @@ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, * \brief Get current batch size of iter */ MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); + mx_uint* batch_size); /*! * \brief Get the padding number in current data batch diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 4ef650ccce46..57db5f2d7846 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -103,7 +103,7 @@ class BatchLoader : public IIterator { d.data[1].get()); mshadow::Copy(out_.data[0].get()[top], d.data[0].get()); - if (++ top >= param_.batch_size) { + if (++ top >= param_.batch_size) { return true; } } From c17affdd841024884c6ca17e7765d3933f82ef8f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 17:03:52 +0800 Subject: [PATCH 19/23] change c_api index interface --- include/mxnet/c_api.h | 17 +++-------------- python/mxnet/io.py | 34 +++++++++++++--------------------- src/c_api/c_api.cc | 23 +++-------------------- src/io/iter_prefetcher.h | 6 +++--- 4 files changed, 22 insertions(+), 58 deletions(-) diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h index 370c223b7130..f82fb8ddcdbc 100644 --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -711,22 +711,11 @@ MXNET_DLL int MXDataIterGetData(DataIterHandle handle, /*! * \brief Get the image index by array * \param handle the handle pointer to the data iterator - * \param batch size - * \return image index array + * \return image index array and array size, index is const data */ MXNET_DLL int MXDataIterGetIndex(DataIterHandle handle, - uint64_t *out_index); -/*! - * \brief Get current batch size of iter - */ -MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); -/*! - * \brief Get current batch size of iter - */ -MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, - mx_uint* batch_size); - + uint64_t **out_index, + uint64_t *out_size); /*! * \brief Get the padding number in current data batch * \param handle the handle pointer to the data iterator diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 7922aaf210ba..1eb97b55beb2 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -83,15 +83,6 @@ def getindex(self): """ pass - def getbatchsize(self): - """ - Retures - ------- - batch_size: int - The size of current batch - """ - pass - def getpad(self): """Get the number of padding examples in current batch. Returns @@ -102,7 +93,7 @@ def getpad(self): pass -DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad']) +DataBatch = namedtuple('DataBatch', ['data', 'label', 'pad', 'index']) def _init_data(data, allow_empty, default_name): """Convert data into canonical form.""" @@ -290,7 +281,7 @@ def reset(self): def next(self): if self._debug_skip_load and not self._debug_at_begin: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) if self.first_batch is not None: batch = self.first_batch self.first_batch = None @@ -299,7 +290,9 @@ def next(self): next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad()) + print 'label', self.getlabel().asnumpy() + print 'index', self.getindex() + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) else: raise StopIteration @@ -321,15 +314,14 @@ def getlabel(self): return NDArray(hdl, False) def getindex(self): - batch_size = self.getbatchsize() - index = np.zeros((batch_size), dtype=np.uint64) - check_call(_LIB.MXDataIterGetIndex(self.handle, index.ctypes.data)) - return index - - def getbatchsize(self): - batch_size = ctypes.c_uint32(0) - check_call(_LIB.MXDataIterGetBatchsize(self.handle, ctypes.byref(batch_size))) - return batch_size.value + index_size = ctypes.c_uint64(0) + index_data = ctypes.POINTER(ctypes.c_uint64)() + check_call(_LIB.MXDataIterGetIndex(self.handle, + ctypes.byref(index_data), + ctypes.byref(index_size))) + dbuffer = (ctypes.c_uint64* index_size.value).from_address(ctypes.addressof(index_data.contents)) + np_index = np.frombuffer(dbuffer, dtype=np.uint64) + return np_index.copy() def getpad(self): pad = ctypes.c_int(0) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 76c27952cbb5..5fbdae0f720c 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -838,28 +838,11 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_END(); } -int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { +int MXDataIterGetIndex(DataIterHandle handle,uint64_t **out_index,uint64_t *out_size) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); - for (size_t i = 0; i < db.index.size(); ++i) { - out_index[i] = db.index[i]; - } - API_END(); -} - -int MXDataIterGetIndex(DataIterHandle handle, uint64_t *out_index) { - API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - for (size_t i = 0; i < db.index.size(); ++i) { - out_index[i] = db.index[i]; - } - API_END(); -} - -MXNET_DLL int MXDataIterGetBatchsize(DataIterHandle handle, mx_uint* batch_size) { - API_BEGIN(); - const DataBatch& db = static_cast* >(handle)->Value(); - *batch_size = (mx_uint)db.index.size(); + *out_size = db.index.size(); + *out_index = const_cast(db.index.data()); API_END(); } diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index ea4c3b08cd9f..0474ed96157a 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -80,9 +80,9 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - for (size_t i = 0; i < batch.batch_size; ++i) { - (*dptr)->index[i] = batch.inst_index[i]; - } + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); return true; }, [this]() { loader_->BeforeFirst(); }); From baf36844c78aee07f7d9d478cd145fd6334e9235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 17:48:19 +0800 Subject: [PATCH 20/23] fix lint --- python/mxnet/io.py | 13 +++++++------ src/c_api/c_api.cc | 2 +- src/io/iter_prefetcher.h | 5 +++-- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 1eb97b55beb2..7a292c8210fe 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -281,7 +281,8 @@ def reset(self): def next(self): if self._debug_skip_load and not self._debug_at_begin: - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), + index=self.getindex()) if self.first_batch is not None: batch = self.first_batch self.first_batch = None @@ -290,9 +291,8 @@ def next(self): next_res = ctypes.c_int(0) check_call(_LIB.MXDataIterNext(self.handle, ctypes.byref(next_res))) if next_res.value: - print 'label', self.getlabel().asnumpy() - print 'index', self.getindex() - return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), index=self.getindex()) + return DataBatch(data=[self.getdata()], label=[self.getlabel()], pad=self.getpad(), + index=self.getindex()) else: raise StopIteration @@ -316,10 +316,11 @@ def getlabel(self): def getindex(self): index_size = ctypes.c_uint64(0) index_data = ctypes.POINTER(ctypes.c_uint64)() - check_call(_LIB.MXDataIterGetIndex(self.handle, + check_call(_LIB.MXDataIterGetIndex(self.handle, ctypes.byref(index_data), ctypes.byref(index_size))) - dbuffer = (ctypes.c_uint64* index_size.value).from_address(ctypes.addressof(index_data.contents)) + address = ctypes.addressof(index_data.contents) + dbuffer = (ctypes.c_uint64* index_size.value).from_address(address) np_index = np.frombuffer(dbuffer, dtype=np.uint64) return np_index.copy() diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 5fbdae0f720c..f822f88da1b8 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -838,7 +838,7 @@ int MXDataIterGetLabel(DataIterHandle handle, NDArrayHandle *out) { API_END(); } -int MXDataIterGetIndex(DataIterHandle handle,uint64_t **out_index,uint64_t *out_size) { +int MXDataIterGetIndex(DataIterHandle handle, uint64_t **out_index, uint64_t *out_size) { API_BEGIN(); const DataBatch& db = static_cast* >(handle)->Value(); *out_size = db.index.size(); diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 0474ed96157a..811ca1b15a21 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -16,6 +16,7 @@ #include #include #include +#include #include "./inst_vector.h" namespace mxnet { @@ -80,8 +81,8 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, (*dptr)->index.begin()); return true; }, From 909ca3ee230e213dcacccc96772e5c21eed02886 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 19:56:01 +0800 Subject: [PATCH 21/23] fix coredown --- src/io/iter_prefetcher.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 811ca1b15a21..00346a68a428 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,7 +68,7 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size); + (*dptr)->index.resize(batch.batch_size()); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } From 2fc16a35ed89093f5736f13648d083e2e04a9a0b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 20:34:37 +0800 Subject: [PATCH 22/23] fix bug with inst_index is null --- src/io/iter_prefetcher.h | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/io/iter_prefetcher.h b/src/io/iter_prefetcher.h index 00346a68a428..0765827df13a 100644 --- a/src/io/iter_prefetcher.h +++ b/src/io/iter_prefetcher.h @@ -68,7 +68,7 @@ class PrefetcherIter : public IIterator { *dptr = new DataBatch(); (*dptr)->num_batch_padd = batch.num_batch_padd; (*dptr)->data.resize(batch.data.size()); - (*dptr)->index.resize(batch.batch_size()); + (*dptr)->index.resize(batch.batch_size); for (size_t i = 0; i < batch.data.size(); ++i) { (*dptr)->data.at(i) = NDArray(batch.data[i].shape_, Context::CPU()); } @@ -81,10 +81,12 @@ class PrefetcherIter : public IIterator { batch.data[i].FlatTo2D()); (*dptr)->num_batch_padd = batch.num_batch_padd; } - std::copy(batch.inst_index, - batch.inst_index + batch.batch_size, - (*dptr)->index.begin()); - return true; + if (batch.inst_index) { + std::copy(batch.inst_index, + batch.inst_index + batch.batch_size, + (*dptr)->index.begin()); + } + return true; }, [this]() { loader_->BeforeFirst(); }); } From 04a9a79de34990dfaaf19780858edcd5cf3248b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BD=95=E4=BF=8A=E7=84=B6?= Date: Mon, 9 Nov 2015 20:52:39 +0800 Subject: [PATCH 23/23] fix io index --- python/mxnet/io.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mxnet/io.py b/python/mxnet/io.py index 7a292c8210fe..d6269d6752a4 100644 --- a/python/mxnet/io.py +++ b/python/mxnet/io.py @@ -207,7 +207,8 @@ def iter_next(self): def next(self): if self.iter_next(): - return DataBatch(data=self.getdata(), label=self.getlabel(), pad=self.getpad()) + return DataBatch(data=self.getdata(), label=self.getlabel(), \ + pad=self.getpad(), index=None) else: raise StopIteration