diff --git a/src/io/iter_batchloader.h b/src/io/iter_batchloader.h index 00f46ec4f721..1ec28d13d65a 100644 --- a/src/io/iter_batchloader.h +++ b/src/io/iter_batchloader.h @@ -56,8 +56,7 @@ class BatchLoader : public IIterator { std::vector > kwargs_left; // init batch param, it could have similar param with kwargs_left = param_.InitAllowUnknown(kwargs); - // init base iterator - base_->Init(kwargs); + // init object attributes std::vector data_shape_vec; data_shape_vec.push_back(param_.batch_size); for (size_t shape_dim = 0; shape_dim < param_.data_shape.ndim(); ++shape_dim) { @@ -75,6 +74,8 @@ class BatchLoader : public IIterator { label_holder_ = mshadow::NewTensor(label_shape_.get<2>(), 0.0f); out_.data.push_back(TBlob(data_holder_)); out_.data.push_back(TBlob(label_holder_)); + // init base iterator + base_->Init(kwargs); } inline void BeforeFirst(void) { if (param_.round_batch == 0 || num_overflow_ == 0) { diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py index ed9ce358f24a..5ece7bac8023 100644 --- a/tests/python/unittest/test_io.py +++ b/tests/python/unittest/test_io.py @@ -85,6 +85,6 @@ def test_NDArrayIter(): assert(labelcount[i] == 100) if __name__ == "__main__": - test_NumpyIter() + #test_NDArrayIter() #test_MNISTIter() - #test_Cifar10Rec() + test_Cifar10Rec()