Skip to content

Commit

Permalink
set to need_grad in other places
Browse files Browse the repository at this point in the history
  • Loading branch information
azai91 committed Nov 29, 2018
1 parent e5c5976 commit 895fd9c
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
AccReal *mean = meanVector.dptr<AccReal>();
AccReal *var = varianceVector.dptr<AccReal>();

const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
const bool is_train_and_not_global_stats = ctx.need_grad && !param_.use_global_stats;
const size_t channelCount = inputData.ChannelCount();
const size_t itemCountPerChannel = inputData.Size() / channelCount;

Expand Down Expand Up @@ -226,7 +226,7 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
AccReal *gradWeightData = gradWeight.dptr<AccReal>();
AccReal *gradBiasData = gradBias.dptr<AccReal>();

const bool is_train_and_not_global_stats = ctx.is_train && !param_.use_global_stats;
const bool is_train_and_not_global_stats = ctx.need_grad && !param_.use_global_stats;

#pragma omp parallel for
for (int channel = 0; channel < static_cast<int>(channelCount); ++channel) {
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
auto data_mem = in_data.GetMKLDNNDataReorder(
fwd_pd.diff_dst_primitive_desc());
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
if (ctx.need_grad) {
// TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
// to the default format for now.
if (weight.IsMKLDNNData())
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_lrn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ static MKLDNNLRNFwd &GetLRNFwd(const LRNParam& param,
OpHash> lrn_fwds;
#endif
auto kind_ =
ctx.is_train ? prop_kind::forward_training : prop_kind::forward_scoring;
ctx.need_grad ? prop_kind::forward_training : prop_kind::forward_scoring;

MKLDNNLRNSignature key(param);
key.AddSign(kind_);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_pooling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ MKLDNNPoolingFwd &GetPoolingFwd(const PoolingParam &param,
void MKLDNNPoolingCompute(const OpContext &ctx, const PoolingParam &param,
const NDArray &in_data, const OpReqType req,
const NDArray &out_data, const NDArray *workspace) {
auto &fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data);
auto &fwd = GetPoolingFwd(param, ctx.need_grad, in_data, out_data);
fwd.SetNewMem(in_data, out_data, req, workspace);
fwd.Execute(out_data);
}
Expand Down

0 comments on commit 895fd9c

Please sign in to comment.