From 9e4c800ce84ba9b35e8f9fe8cb1772883b958c76 Mon Sep 17 00:00:00 2001 From: ArmageddonKnight Date: Tue, 27 Aug 2019 14:02:21 -0400 Subject: [PATCH 1/4] Added (CuDNN)BatchNorm operator to the list of mirrored operators --- src/executor/graph_executor.cc | 2 -- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 17 ++++++++++++++++- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 7bdeac708003..36a3a5414b9e 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -359,8 +359,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol, if (type == "FullyConnected") return false; if (type == "Concat") return false; if (type == "SoftmaxOutput") return false; - if (type == "BatchNorm") return false; - if (type == "CuDNNBatchNorm") return false; return true; }; diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index 3fc91196708c..ff376163f02e 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -55,6 +55,7 @@ class CuDNNBatchNormOp { dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType::kFlag; CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_)); CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_)); + internal_aux_states_lock_ = false; } void Init(const BatchNormParam &param) { @@ -122,6 +123,12 @@ class CuDNNBatchNormOp { Tensor save_inv_var = out_data[cudnnbatchnorm::kInvVar] .get_with_shape(Shape1(shape_[1]), s); + // If the lock on the auxiliary states is set, + // then this implies that the preceding call is also a `Forward()` call, + // which further indicates that we are in the backward mirroring mode, + // and therefore update to the auxiliary states is disabled. + // This is done by setting the `momentum` to `1` (or `factor` to `0`). + float factor = internal_aux_states_lock_ ? 
0 : (1 - param_.momentum); CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, mode, &a, @@ -133,7 +140,7 @@ class CuDNNBatchNormOp { mean_desc_, gamma.dptr_, beta.dptr_, - 1 - param_.momentum, + factor, moving_mean.dptr_, moving_inv_var.dptr_, param_.eps, @@ -156,6 +163,10 @@ class CuDNNBatchNormOp { param_.eps)); } }) + // Set the lock on the auxiliary states. + // If the next call to the operator is a `Forward()` call, + // then `momentum` will be set to `1` and hence auxiliary states will not be updated. + internal_aux_states_lock_ = true; } void Backward(const OpContext &ctx, @@ -230,6 +241,9 @@ class CuDNNBatchNormOp { global_stats ? nullptr : save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) + // Release the lock on the auxiliary states, so that the next forward pass + // will be able to update the auxiliary states normally. + internal_aux_states_lock_ = false; } private: @@ -262,6 +276,7 @@ class CuDNNBatchNormOp { cudnnTensorDescriptor_t io_desc_, mean_desc_; mshadow::Shape<4> shape_; BatchNormParam param_; + bool internal_aux_states_lock_; }; #endif // defined(__CUDACC__) From 25551fb79f3c514cc64b5876ec32a20954fe931a Mon Sep 17 00:00:00 2001 From: JackieWu Date: Fri, 18 Oct 2019 17:02:15 +0800 Subject: [PATCH 2/4] ci From 8acdb805c0037a8b0c41cce6a1c823626f4d57c4 Mon Sep 17 00:00:00 2001 From: ArmageddonKnight Date: Fri, 25 Oct 2019 01:16:09 -0400 Subject: [PATCH 3/4] Enable the auxiliary state locking only in the backward mirroring mode --- src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index ff376163f02e..881d3d2247da 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -128,7 +128,8 @@ class CuDNNBatchNormOp { // which further indicates that we are in the backward mirroring mode, // and therefore update to the 
auxiliary states is disabled. // This is done by setting the `momentum` to `1` (or `factor` to `0`). - float factor = internal_aux_states_lock_ ? 0 : (1 - param_.momentum); + float factor = (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) && internal_aux_states_lock_) ? + 0 : (1 - param_.momentum); CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_, mode, &a, From e4005ac0f1299b8336529142c1e80a0b8b86e9b9 Mon Sep 17 00:00:00 2001 From: JackieWu Date: Tue, 12 Nov 2019 12:30:55 +0800 Subject: [PATCH 4/4] retrigger CI