
fix mkldnn bn output when use_global_stats
ElaineBao committed Aug 21, 2019
1 parent 3086347 commit 201d869
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -51,15 +51,15 @@ using mkldnn::forward_inference;
 
 inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
                                  const std::vector<NDArray> &aux_states,
-                                 const BatchNormParam &param, bool is_train) {
+                                 const BatchNormParam &param, bool is_train_and_not_global_stats) {
   unsigned flags = 0U;
   if (in_data.size() == 3U) {
     flags |= use_scale_shift;
   }
 
   // aux_states[0]: inMean
   // aux_states[1]: inVariance
-  if (aux_states.size() == 2U && !is_train) {
+  if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
     flags |= use_global_stats;
   }
   return flags;
@@ -107,13 +107,13 @@ class MKLDNNBNForward {
   std::shared_ptr<const mkldnn::memory> mean_m;
   std::shared_ptr<const mkldnn::memory> var_m;
   std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
-  bool is_train;
+  bool is_train_and_not_global_stats;
   t_bn_f_pdesc pd;
 
  public:
-  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train): pd(_pd) {
+  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
     weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
-    this->is_train = is_train;
+    this->is_train_and_not_global_stats = is_train_and_not_global_stats;
   }
 
   const mkldnn::memory &GetWeight() const {
@@ -161,7 +161,7 @@ class MKLDNNBNForward {
     }
 
     if (fwd == nullptr) {
-      if (!is_train)
+      if (!is_train_and_not_global_stats)
         fwd.reset(new mkldnn::batch_normalization_forward(
                 pd, *data_m, mkldnn::primitive::at(*mean_m),
                 mkldnn::primitive::at(*var_m), *weight_m, *out_m));
@@ -194,13 +194,14 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 #endif
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
+  key.AddSign(param.use_global_stats);
   key.AddSign(in_data);
 
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
                           (DType) param.eps, flags);
-    MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
+    MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
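
Note on the hunk above: the added key.AddSign(param.use_global_stats) is the caching half of the fix. The forward-primitive cache was previously signed with ctx.is_train alone, so a BatchNorm call with use_global_stats=true could be served a cached primitive built to use batch statistics. A minimal sketch of the idea follows; the types and names are illustrative stand-ins, not MXNet's actual cache machinery.

// Sketch only: every flag that changes primitive behavior must sign the
// cache key. Key, BuildPrimitive, and GetCached are hypothetical names.
#include <iostream>
#include <map>
#include <string>
#include <utility>

using Key = std::pair<bool, bool>;  // (is_train, use_global_stats)

std::string BuildPrimitive(bool is_train, bool use_global_stats) {
  // Mirrors the fixed constructor argument: batch statistics are used only
  // when training without pinned global statistics.
  return (is_train && !use_global_stats) ? "BN(batch stats)"
                                         : "BN(global stats)";
}

const std::string &GetCached(bool is_train, bool use_global_stats) {
  static std::map<Key, std::string> cache;
  Key key{is_train, use_global_stats};  // both flags sign the key
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, BuildPrimitive(is_train, use_global_stats)).first;
  return it->second;
}

int main() {
  // Were the key (is_train) alone, the second call would wrongly hit the
  // first call's entry and run with batch statistics.
  std::cout << GetCached(true, false) << '\n';  // BN(batch stats)
  std::cout << GetCached(true, true) << '\n';   // BN(global stats)
}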
@@ -213,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
                             const std::vector<NDArray> &out_data,
                             const std::vector<NDArray> &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
   const NDArray &data = in_data[batchnorm::kData];
 
   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
@@ -378,7 +379,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
 
   const NDArray &data = in_data[batchnorm::kData];
   const NDArray &diff = out_grad[batchnorm::kOut];
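
Note on the recurring ctx.is_train && !param.use_global_stats expression: it is the output half of the fix. _GetFlags sets MKL-DNN's use_global_stats flag exactly when its boolean argument is false, so passing plain ctx.is_train meant the flag was never set during training, even when the user had requested global statistics; the output was then computed from batch statistics. A standalone sketch of the resulting decision follows; OpContext and BatchNormParam here are minimal stand-ins for the real structs.

// Sketch only: the effective "use batch statistics" decision after the fix.
#include <cstdio>

struct OpContext { bool is_train; };
struct BatchNormParam { bool use_global_stats; };

// _GetFlags sets the use_global_stats flag exactly when this returns false.
bool UseBatchStats(const OpContext &ctx, const BatchNormParam &param) {
  return ctx.is_train && !param.use_global_stats;
}

int main() {
  std::printf("train, batch stats  -> %d\n", UseBatchStats({true}, {false}));   // 1
  std::printf("train, global stats -> %d\n", UseBatchStats({true}, {true}));    // 0 (the fixed case)
  std::printf("inference           -> %d\n", UseBatchStats({false}, {false}));  // 0
}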
