
[MXNET-1421] Added (CuDNN)BatchNorm operator to the list of mirrored operators #16022

Merged
2 changes: 0 additions & 2 deletions src/executor/graph_executor.cc
@@ -359,8 +359,6 @@ nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
     if (type == "FullyConnected") return false;
     if (type == "Concat") return false;
     if (type == "SoftmaxOutput") return false;
-    if (type == "BatchNorm") return false;
-    if (type == "CuDNNBatchNorm") return false;
     return true;
   };

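Context for the deletion above: this lambda is the executor's mirroring predicate, and operator types for which it returns false are never recomputed in the backward pass. With the two lines removed, BatchNorm and CuDNNBatchNorm fall through to `return true;` and become eligible for forward recompute when mirroring is enabled. A schematic sketch of the predicate as it reads after this change (simplified; the real lambda lives in `GraphExecutor::InitFullGraph`):

```cpp
// Schematic sketch, simplified from graph_executor.cc: operators for which
// this returns true may have their forward pass re-executed during backprop
// (controlled by MXNET_BACKWARD_DO_MIRROR) to trade compute for memory.
auto need_mirror = [](const std::string &type) {
  if (type == "FullyConnected") return false;
  if (type == "Concat") return false;
  if (type == "SoftmaxOutput") return false;
  return true;  // BatchNorm / CuDNNBatchNorm are no longer excluded
};
```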
18 changes: 17 additions & 1 deletion src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -55,6 +55,7 @@ class CuDNNBatchNormOp {
     dtype_param_ = (dtype_ == CUDNN_DATA_HALF) ? kFloat32 : DataType<DType>::kFlag;
     CUDNN_CALL(cudnnCreateTensorDescriptor(&io_desc_));
     CUDNN_CALL(cudnnCreateTensorDescriptor(&mean_desc_));
+    internal_aux_states_lock_ = false;
   }

   void Init(const BatchNormParam &param) {
@@ -122,6 +123,13 @@ class CuDNNBatchNormOp {
       Tensor<gpu, 1, DTypeParam> save_inv_var =
         out_data[cudnnbatchnorm::kInvVar]
         .get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
+      // If the lock on the auxiliary states is set, then the preceding call
+      // was also a `Forward()` call, which means that we are in the backward
+      // mirroring mode and updates to the auxiliary states are therefore
+      // disabled. This is done by setting `momentum` to `1` (equivalently,
+      // `factor` to `0`).
+      float factor = (dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0) && internal_aux_states_lock_) ?
+          0 : (1 - param_.momentum);
       CUDNN_CALL(cudnnBatchNormalizationForwardTraining(s->dnn_handle_,
                                                         mode,
                                                         &a,
@@ -133,7 +141,7 @@
                                                         mean_desc_,
                                                         gamma.dptr_,
                                                         beta.dptr_,
-                                                        1 - param_.momentum,
+                                                        factor,
                                                         moving_mean.dptr_,
                                                         moving_inv_var.dptr_,
                                                         param_.eps,
@@ -156,6 +164,10 @@
                                                         param_.eps));
       }
     })
+    // Set the lock on the auxiliary states. If the next call to the operator
+    // is another `Forward()`, `momentum` will be set to `1` and the auxiliary
+    // states will therefore not be updated.
+    internal_aux_states_lock_ = true;
   }

   void Backward(const OpContext &ctx,
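The `factor` argument above is cuDNN's exponential-average factor: `cudnnBatchNormalizationForwardTraining` updates the running statistics as `runningMean = runningMean * (1 - factor) + batchMean * factor`, so `factor = 0` makes the update a no-op. A minimal standalone sketch (hypothetical names, not part of this PR) of why the mirrored recompute then leaves the auxiliary states untouched:

```cpp
#include <cassert>

// cuDNN's running-statistics update rule, modeled on a single scalar.
double UpdateRunningMean(double running_mean, double batch_mean, double factor) {
  return running_mean * (1.0 - factor) + batch_mean * factor;
}

int main() {
  const double momentum = 0.9;
  double running_mean = 0.5;
  const double batch_mean = 2.0;

  // Regular Forward(): the lock is clear, factor = 1 - momentum, and the
  // running mean absorbs the current batch exactly once.
  running_mean = UpdateRunningMean(running_mean, batch_mean, 1.0 - momentum);

  // Mirrored Forward() (recompute during backprop): the lock is set, so
  // factor = 0 and replaying the same batch changes nothing.
  const double replayed = UpdateRunningMean(running_mean, batch_mean, 0.0);
  assert(replayed == running_mean);
  return 0;
}
```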
@@ -230,6 +242,9 @@
                                                global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
     })
+    // Release the lock on the auxiliary states so that the next forward pass
+    // can update them normally.
+    internal_aux_states_lock_ = false;
   }

  private:
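Taken together, the two hunks above implement a two-state lock: set at the end of every `Forward()`, cleared at the end of `Backward()`. A hypothetical model (not the actual operator code, and omitting the `MXNET_BACKWARD_DO_MIRROR` check for brevity) that captures the resulting factor schedule across one mirrored training iteration:

```cpp
#include <cassert>

// Hypothetical model of internal_aux_states_lock_: Forward() reports the
// factor it would pass to cuDNN, then sets the lock; Backward() releases it.
struct LockModel {
  bool locked = false;  // initialized to false, as in the constructor above

  double Forward(double momentum) {
    const double factor = locked ? 0.0 : 1.0 - momentum;
    locked = true;  // a second Forward() before Backward() implies mirroring
    return factor;
  }

  void Backward() { locked = false; }  // next iteration updates normally
};

int main() {
  LockModel op;
  assert(op.Forward(0.5) == 0.5);  // regular pass: stats updated
  assert(op.Forward(0.5) == 0.0);  // mirrored recompute: stats frozen
  op.Backward();
  assert(op.Forward(0.5) == 0.5);  // fresh iteration after Backward()
  return 0;
}
```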
@@ -262,6 +277,7 @@
   cudnnTensorDescriptor_t io_desc_, mean_desc_;
   mshadow::Shape<4> shape_;
   BatchNormParam param_;
+  bool internal_aux_states_lock_;
 };
 #endif  // defined(__CUDACC__)