diff --git a/src/c_api.cc b/src/c_api.cc index 4a96d946d34f..9b31b8e47641 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -180,6 +180,12 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, API_END(); } +int MXEngineWaitAll() { + API_BEGIN(); + Engine::Get()->WaitForAll(); + API_END(); +} + // NOTE: return value is added in API_END int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); @@ -323,12 +329,6 @@ int MXNDArrayListLoad(const char* fname, API_END(); } -int MXNDArrayWaitAll() { - API_BEGIN(); - Engine::Get()->WaitForAll(); - API_END(); -} - int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast(handle); @@ -817,6 +817,9 @@ int MXDataIterBeforeFirst(DataIterHandle handle) { int MXDataIterNext(DataIterHandle handle, int *out) { API_BEGIN(); + // TODO(tianjun): remove this after having prefetcher by default. + // and call NArray.WaitForWrite instead. + Engine::Get()->WaitForAll(); *out = static_cast* >(handle)->Next(); API_END(); } diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 93480e10357b..4c60ebf92ecb 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -170,10 +170,10 @@ class BatchNormOp : public Operator { // TODO(bing): use global memory allocator inline void Init(const OpContext &ctx, const mshadow::Shape<4> &dshape) { - if (is_init) return; - is_init = true; mshadow::Stream *s = ctx.get_stream(); tmp_.set_stream(s); + if (is_init) return; + is_init = true; tmp_.Resize(mshadow::Shape2(3, dshape[1])); } mshadow::TensorContainer tmp_; diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 8b81818304e1..38397a931096 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -149,9 +149,9 @@ class CuDNNConvolutionOp : public Operator { size_t expected = param_.no_bias ? 2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); + temp_.set_stream(s); if (!init_cudnn_) { init_cudnn_ = true; - temp_.set_stream(s); size_t workspace = static_cast(param_.workspace); size_t back_size = 0; size_t back_size_w = 0;