From aae950115a2bf478321f137966c0ab35d9846225 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 16:55:53 -0600 Subject: [PATCH 1/2] Bugfix operator temp space need to change as stream changes --- src/c_api.cc | 1 + src/operator/batch_norm-inl.h | 4 ++-- src/operator/cudnn_convolution-inl.h | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/c_api.cc b/src/c_api.cc index 4a96d946d34f..7950a2f4a2bb 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -817,6 +817,7 @@ int MXDataIterBeforeFirst(DataIterHandle handle) { int MXDataIterNext(DataIterHandle handle, int *out) { API_BEGIN(); + Engine::Get()->WaitForAll(); *out = static_cast<IIterator<DataBatch>* >(handle)->Next(); API_END(); } diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h index 93480e10357b..4c60ebf92ecb 100644 --- a/src/operator/batch_norm-inl.h +++ b/src/operator/batch_norm-inl.h @@ -170,10 +170,10 @@ class BatchNormOp : public Operator { // TODO(bing): use global memory allocator inline void Init(const OpContext &ctx, const mshadow::Shape<4> &dshape) { - if (is_init) return; - is_init = true; mshadow::Stream<xpu> *s = ctx.get_stream<xpu>(); tmp_.set_stream(s); + if (is_init) return; + is_init = true; tmp_.Resize(mshadow::Shape2(3, dshape[1])); } mshadow::TensorContainer<xpu, 2> tmp_; diff --git a/src/operator/cudnn_convolution-inl.h b/src/operator/cudnn_convolution-inl.h index 8b81818304e1..38397a931096 100644 --- a/src/operator/cudnn_convolution-inl.h +++ b/src/operator/cudnn_convolution-inl.h @@ -149,9 +149,9 @@ class CuDNNConvolutionOp : public Operator { size_t expected = param_.no_bias ? 
2 : 3; CHECK_EQ(in_data.size(), expected); CHECK_EQ(out_data.size(), 1); + temp_.set_stream(s); if (!init_cudnn_) { init_cudnn_ = true; - temp_.set_stream(s); size_t workspace = static_cast<size_t>(param_.workspace); size_t back_size = 0; size_t back_size_w = 0; From d7402c46192bfe1fccec871488fe28e9647f3637 Mon Sep 17 00:00:00 2001 From: tqchen Date: Mon, 14 Sep 2015 17:00:13 -0600 Subject: [PATCH 2/2] mark todo --- src/c_api.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/c_api.cc b/src/c_api.cc index 7950a2f4a2bb..9b31b8e47641 100644 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -180,6 +180,12 @@ inline int MXAPIGetFunctionRegInfo(const FunRegType *e, API_END(); } +int MXEngineWaitAll() { + API_BEGIN(); + Engine::Get()->WaitForAll(); + API_END(); +} + // NOTE: return value is added in API_END int MXNDArrayCreateNone(NDArrayHandle *out) { API_BEGIN(); @@ -323,12 +329,6 @@ int MXNDArrayListLoad(const char* fname, API_END(); } -int MXNDArrayWaitAll() { - API_BEGIN(); - Engine::Get()->WaitForAll(); - API_END(); -} - int MXNDArrayFree(NDArrayHandle handle) { API_BEGIN(); delete static_cast<NDArray*>(handle); @@ -817,6 +817,8 @@ int MXDataIterBeforeFirst(DataIterHandle handle) { int MXDataIterNext(DataIterHandle handle, int *out) { API_BEGIN(); + // TODO(tianjun): remove this after having prefetcher by default. + // and call NArray.WaitForWrite instead. Engine::Get()->WaitForAll(); *out = static_cast<IIterator<DataBatch>* >(handle)->Next(); API_END();