diff --git a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h index cb3a7566c078..2d2bf2c64596 100644 --- a/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h @@ -214,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam ¶m, const std::vector &out_data, const std::vector &aux_states) { TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]); - unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); + unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats); const NDArray &data = in_data[batchnorm::kData]; auto &fwd = GetBNForward(param, ctx, data, flags);