Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #168 from sneakerkg/master
Browse files Browse the repository at this point in the history
remove python error when path incorrect
  • Loading branch information
tqchen committed Sep 27, 2015
2 parents ec73910 + 676fd47 commit 88a1bad
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/io/iter_batchloader.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ class BatchLoader : public IIterator<TBlobBatch> {
std::vector<std::pair<std::string, std::string> > kwargs_left;
// init batch param, it could have similar param with
kwargs_left = param_.InitAllowUnknown(kwargs);
// init base iterator
base_->Init(kwargs);
// init object attributes
std::vector<size_t> data_shape_vec;
data_shape_vec.push_back(param_.batch_size);
for (size_t shape_dim = 0; shape_dim < param_.data_shape.ndim(); ++shape_dim) {
Expand All @@ -75,6 +74,8 @@ class BatchLoader : public IIterator<TBlobBatch> {
label_holder_ = mshadow::NewTensor<mshadow::cpu>(label_shape_.get<2>(), 0.0f);
out_.data.push_back(TBlob(data_holder_));
out_.data.push_back(TBlob(label_holder_));
// init base iterator
base_->Init(kwargs);
}
inline void BeforeFirst(void) {
if (param_.round_batch == 0 || num_overflow_ == 0) {
Expand Down
4 changes: 2 additions & 2 deletions tests/python/unittest/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def test_NDArrayIter():
assert(labelcount[i] == 100)

if __name__ == "__main__":
test_NumpyIter()
#test_NDArrayIter()
#test_MNISTIter()
#test_Cifar10Rec()
test_Cifar10Rec()

0 comments on commit 88a1bad

Please sign in to comment.