diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7bdeac708003..36a3a5414b9e 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -359,8 +359,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, if (type == "FullyConnected") return false; if (type == "Concat") return false; if (type == "SoftmaxOutput") return false; - if (type == "BatchNorm") return false; - if (type == "CuDNNBatchNorm") return false; return true; }; diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 3fc91196708c..881d3d2247da 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -55,6 +55,7 @@ class CuDNNBatchNormOp { dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType::kFlag; CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_)); CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_)); + internal_aux_states_lock_ = false; } void Init(const BatchNormParam &param) { @@ -122,6 +123,13 @@ class CuDNNBatchNormOp { Tensor save_inv_var = out_data[cudnnbatchnorm::kInvVar] .get_with_shape(Shape1(shape_[1]), s); + // If the lock on the auxiliary states is set, + // then this implies that the preceding call is also a `Forward()` call, + // which further indicates that we are in the backward mirroring mode, + // and therefore update to the auxiliary states is disabled. + // This is done by setting the `momentum` to `1` (or `factor` to `0`). + float factor = (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) && internal_aux_states_lock_) ? + 0 : (1 - param_.momentum); CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, mode, &a, @@ -133,7 +141,7 @@ class CuDNNBatchNormOp { mean_desc_, gamma.dptr_, beta.dptr_, - 1 - param_.momentum, + factor, moving_mean.dptr_, moving_inv_var.dptr_, param_.eps, @@ -156,6 +164,10 @@ class CuDNNBatchNormOp { param_.eps)); } }) + // Set the lock on the auxiliary states. 
+ // If the next call to the operator is a `Forward()` call, + // then `momentum` will be set to `1` and hence auxiliary states will not be updated. + internal_aux_states_lock_ = true; } void Backward(const OpContext &ctx, @@ -230,6 +242,9 @@ class CuDNNBatchNormOp { global_stats ? nullptr : save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) + // Release the lock on the auxiliary states, so that the next forward pass + // will be able to update the auxiliary states normally. + internal_aux_states_lock_ = false; } private: @@ -262,6 +277,7 @@ class CuDNNBatchNormOp { cudnnTensorDescriptor_t io_desc_, mean_desc_; mshadow::Shape<4> shape_; BatchNormParam param_; + bool internal_aux_states_lock_; }; #endif // defined(__CUDACC__)