diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 6254a1e18662..50a331912697 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -112,7 +112,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
   AccReal *mean = meanVector.dptr<AccReal>();
   AccReal *var = varianceVector.dptr<AccReal>();
 
-  const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
+  const bool is_train_and_not_global_stats = ctx.need_grad && !param_.use_global_stats;
 
   const size_t channelCount = inputData.ChannelCount();
   const size_t itemCountPerChannel = inputData.Size() / channelCount;
@@ -226,7 +226,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
   AccReal *gradWeightData = gradWeight.dptr<AccReal>();
   AccReal *gradBiasData = gradBias.dptr<AccReal>();
 
-  const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
+  const bool is_train_and_not_global_stats = ctx.need_grad && !param_.use_global_stats;
 
   #pragma omp parallel for
   for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
index a6d6b24235c8..67316fb1efb3 100644
--- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
+++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -251,7 +251,7 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
   auto data_mem = in_data.GetMKLDNNDataReorder(
       fwd_pd.diff_dst_primitive_desc());
   const mkldnn::memory *weight_mem;
-  if (ctx.is_train) {
+  if (ctx.need_grad) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
     if (weight.IsMKLDNNData())
diff --git a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
index 31b293a14c2c..325d0f6f3355 100644
--- a/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_lrn-inl.h
@@ -180,7 +180,7 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
                                             OpHash> lrn_fwds;
 #endif
   auto kind_ =
-      ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
+      ctx.need_grad ? prop_kind::forward_training : prop_kind::forward_scoring;
 
   MKLDNNLRNSignature key(param);
   key.AddSign(kind_);
diff --git a/src/operator/nn/mkldnn/mkldnn_pooling.cc b/src/operator/nn/mkldnn/mkldnn_pooling.cc
index f4d681ded78d..994fc2b1310c 100644
--- a/src/operator/nn/mkldnn/mkldnn_pooling.cc
+++ b/src/operator/nn/mkldnn/mkldnn_pooling.cc
@@ -269,7 +269,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
 void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
                           const NDArray &in_data, const OpReqType req,
                           const NDArray &out_data, const NDArray *workspace) {
-  auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
+  auto &fwd = GetPoolingFwd(param, ctx.need_grad, in_data, out_data);
   fwd.SetNewMem(in_data, out_data, req, workspace);
   fwd.Execute(out_data);
 }