This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

update support MKLDNN BN conditions (#15870)
* update support MKLDNN BN conditions

* add bn test case for channel size = 8

* fix bn gradient with use_global_stats

* rm useless comments

* fix mkldnn bn output when use_global_stats

* fix lint

* retrigger ci

* retrigger ci again

* retrigger ci again 2

* retrigger ci again 3

* retrigger ci again 4
ElaineBao authored and pengzhao-intel committed Aug 23, 2019
1 parent 73a692e commit d8c2d85
Showing 3 changed files with 17 additions and 18 deletions.
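
For context on what these fixes target: when use_global_stats is set, BatchNorm normalizes with the moving mean and variance even in training mode, so the training-mode output should match a plain NumPy reference built from the moving statistics. A minimal sketch of that contract (an illustration only, not part of this commit, assuming the standard mx.nd.BatchNorm signature):

    import mxnet as mx
    import numpy as np

    # Illustration only: with use_global_stats=True the moving mean/variance are
    # used for normalization even under autograd.record(), so the output can be
    # checked against a NumPy reference computed from those statistics.
    np.random.seed(0)
    c, eps = 8, 1e-5
    x_np = np.random.normal(size=(2, c, 2, 2)).astype('float32')
    gamma_np = np.random.uniform(0.5, 1.5, c).astype('float32')
    beta_np = np.random.normal(size=c).astype('float32')
    mean_np = np.random.normal(size=c).astype('float32')
    var_np = np.random.uniform(0.5, 1.5, c).astype('float32')

    with mx.autograd.record(train_mode=True):  # training mode, global stats requested
        y = mx.nd.BatchNorm(mx.nd.array(x_np), mx.nd.array(gamma_np), mx.nd.array(beta_np),
                            mx.nd.array(mean_np), mx.nd.array(var_np),
                            eps=eps, fix_gamma=False, use_global_stats=True)

    ref = (x_np - mean_np.reshape(1, c, 1, 1)) / np.sqrt(var_np.reshape(1, c, 1, 1) + eps)
    ref = gamma_np.reshape(1, c, 1, 1) * ref + beta_np.reshape(1, c, 1, 1)
    np.testing.assert_allclose(y.asnumpy(), ref, rtol=1e-3, atol=1e-3)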
src/operator/nn/batch_norm.cc: 4 additions & 6 deletions
@@ -384,7 +384,6 @@ static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
   mxnet::TShape shape = input.shape();
   return SupportMKLDNN(input) && shape.ndim() == 4
          && param.axis == mxnet::op::batchnorm::DEFAULT_AXIS
-         && shape[param.axis] % 8 == 0
          && !mxnet::op::batchnorm::disable_mkl;
 }

@@ -395,8 +394,8 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs &attrs,
                            const std::vector<NDArray> &outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
-  // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param) && inputs[0].IsMKLDNNData()) {
+
+  if (SupportMKLDNNBN(inputs[0], param)) {
     std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
     std::vector<NDArray> aux_states(inputs.begin() + batchnorm::kInMovingMean, inputs.end());

@@ -419,9 +418,8 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs &attrs,
   const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

   mxnet::TShape shape = inputs[0].shape();
-  // MKLDNN batchnorm only works well on the special MKLDNN layout.
-  if (SupportMKLDNNBN(inputs[0], param)
-      && (inputs[3].IsMKLDNNData() || inputs[0].IsMKLDNNData())) {
+
+  if (SupportMKLDNNBN(inputs[0], param)) {
     std::vector<NDArray> out_grad(1);
     std::vector<NDArray> out_data(3);
     std::vector<NDArray> in_data(3);
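
After this change, SupportMKLDNNBN only checks for a 4-D input on the default channel axis, an MKLDNN-supported array, and MKLDNN not being disabled; the channel-divisible-by-8 requirement and the requirement that the input (or output gradient) already be in MKLDNN layout are gone. A quick way to exercise the relaxed condition on a CPU build (a hypothetical cross-check, not part of the commit) is a channel count that is not a multiple of 8:

    import mxnet as mx
    import numpy as np

    # Hypothetical cross-check: channel size 6 is not a multiple of 8, so the old
    # condition skipped MKLDNN for it; the result should match NumPy either way.
    np.random.seed(1)
    c, eps = 6, 1e-5
    x_np = np.random.normal(size=(4, c, 5, 5)).astype('float32')
    mean_np = np.random.normal(size=c).astype('float32')
    var_np = np.random.uniform(0.5, 1.5, c).astype('float32')

    y = mx.nd.BatchNorm(mx.nd.array(x_np), mx.nd.ones(c), mx.nd.zeros(c),
                        mx.nd.array(mean_np), mx.nd.array(var_np),
                        eps=eps, fix_gamma=False, use_global_stats=True)

    ref = (x_np - mean_np.reshape(1, c, 1, 1)) / np.sqrt(var_np.reshape(1, c, 1, 1) + eps)
    np.testing.assert_allclose(y.asnumpy(), ref, rtol=1e-3, atol=1e-3)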
src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h: 11 additions & 10 deletions
@@ -51,15 +51,15 @@ using mkldnn::forward_inference;

 inline static unsigned _GetFlags(const std::vector<NDArray> &in_data,
                                  const std::vector<NDArray> &aux_states,
-                                 const BatchNormParam &param, bool is_train) {
+                                 const BatchNormParam &param, bool is_train_and_not_global_stats) {
   unsigned flags = 0U;
   if (in_data.size() == 3U) {
     flags |= use_scale_shift;
   }

   // aux_states[0]: inMean
   // aux_states[1]: inVariance
-  if (aux_states.size() == 2U && !is_train) {
+  if (aux_states.size() == 2U && !is_train_and_not_global_stats) {
     flags |= use_global_stats;
   }
   return flags;
@@ -107,13 +107,13 @@ class MKLDNNBNForward {
   std::shared_ptr<const mkldnn::memory> mean_m;
   std::shared_ptr<const mkldnn::memory> var_m;
   std::shared_ptr<mkldnn::batch_normalization_forward> fwd;
-  bool is_train;
+  bool is_train_and_not_global_stats;
   t_bn_f_pdesc pd;

  public:
-  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train): pd(_pd) {
+  MKLDNNBNForward(const t_bn_f_pdesc &_pd, bool is_train_and_not_global_stats): pd(_pd) {
     weight_m.reset(new mkldnn::memory(pd.weights_primitive_desc()));
-    this->is_train = is_train;
+    this->is_train_and_not_global_stats = is_train_and_not_global_stats;
   }

   const mkldnn::memory &GetWeight() const {
@@ -161,7 +161,7 @@ class MKLDNNBNForward {
     }

     if (fwd == nullptr) {
-      if (!is_train)
+      if (!is_train_and_not_global_stats)
         fwd.reset(new mkldnn::batch_normalization_forward(
             pd, *data_m, mkldnn::primitive::at(*mean_m),
             mkldnn::primitive::at(*var_m), *weight_m, *out_m));
@@ -194,13 +194,14 @@ static MKLDNNBNForward &GetBNForward(const BatchNormParam& param,
 #endif
   MKLDNNBNSignature key(param);
   key.AddSign(ctx.is_train);
+  key.AddSign(param.use_global_stats);
   key.AddSign(in_data);

   auto it = fwds.find(key);
   if (it == fwds.end()) {
     auto fwd_pd = _GetFwd(*in_data.GetMKLDNNData(), ctx.is_train,
                           (DType) param.eps, flags);
-    MKLDNNBNForward fwd(fwd_pd, ctx.is_train);
+    MKLDNNBNForward fwd(fwd_pd, ctx.is_train && !param.use_global_stats);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
@@ -213,7 +214,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
                             const std::vector<NDArray> &out_data,
                             const std::vector<NDArray> &aux_states) {
   TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);
   const NDArray &data = in_data[batchnorm::kData];

   auto &fwd = GetBNForward<DType>(param, ctx, data, flags);
@@ -253,7 +254,7 @@ void MKLDNNBatchNormForward(const OpContext &ctx, const BatchNormParam &param,
     }
   }

-  if (!ctx.is_train) {
+  if (!ctx.is_train || param.use_global_stats) {
     DType* omean = out_data[batchnorm::kMean].data().dptr<DType>();
     DType* ovar = out_data[batchnorm::kVar].data().dptr<DType>();
     DType* inmean = aux_states[batchnorm::kMovingMean].data().dptr<DType>();
@@ -378,7 +379,7 @@ void MKLDNNBatchNormBackward(const OpContext &ctx, const BatchNormParam &param,
   CHECK_EQ(in_data.size(), 3U);
   CHECK_EQ(out_data.size(), 3U);
   CHECK_EQ(in_grad.size(), 3U);
-  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train);
+  unsigned flags = _GetFlags(in_data, aux_states, param, ctx.is_train && !param.use_global_stats);

   const NDArray &data = in_data[batchnorm::kData];
   const NDArray &diff = out_grad[batchnorm::kOut];
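
The header changes above thread a single is_train_and_not_global_stats flag through the flag computation and the cached forward primitive (use_global_stats is also added to the cache key), and copy the moving statistics into the mean/var outputs whenever global stats are in effect. One observable consequence, and the gradient case the commit message refers to: with the moving statistics treated as constants, the input gradient of a sum loss reduces to gamma / sqrt(moving_var + eps), broadcast over the batch and spatial axes. A hedged sketch (illustration only, not part of the commit):

    import mxnet as mx
    import numpy as np

    # Illustration only: under use_global_stats=True the moving mean/variance are
    # constants, so for a sum loss dL/dx is just gamma / sqrt(moving_var + eps),
    # broadcast over N, H, W.
    np.random.seed(2)
    c, eps = 8, 1e-5
    gamma_np = np.random.uniform(0.5, 1.5, c).astype('float32')
    var_np = np.random.uniform(0.5, 1.5, c).astype('float32')

    x = mx.nd.random.normal(shape=(2, c, 4, 4))
    x.attach_grad()
    with mx.autograd.record():
        y = mx.nd.BatchNorm(x, mx.nd.array(gamma_np), mx.nd.zeros(c),
                            mx.nd.zeros(c), mx.nd.array(var_np),
                            eps=eps, fix_gamma=False, use_global_stats=True)
    y.backward()  # head gradient of ones, i.e. a sum() loss

    expected = np.broadcast_to((gamma_np / np.sqrt(var_np + eps)).reshape(1, c, 1, 1),
                               x.shape)
    np.testing.assert_allclose(x.grad.asnumpy(), expected, rtol=1e-3, atol=1e-3)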
tests/python/unittest/test_operator.py: 2 additions & 2 deletions
@@ -1626,7 +1626,7 @@ def test_bilinear_upsampling():
 @with_seed()
 def test_batchnorm_training():
     def check_batchnorm_training(stype):
-        for shape in [(2, 3), (2, 3, 2, 2)]:
+        for shape in [(2, 3), (2, 3, 2, 2), (2, 8, 2, 2)]:
             data_tmp = np.random.normal(-0.1, 0.1, size=shape)
             s = shape[1],
             gamma = np.ones(s)
@@ -1821,7 +1821,7 @@ def _test_batchnorm_impl(op, shape, axis, cudnn_off, output_mean_var):
                 bn_beta.grad.asnumpy(), db.asnumpy(), atol=atol, rtol=rtol)

     for op in [mx.nd.BatchNorm, mx.nd.contrib.SyncBatchNorm]:
-        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 5, 6, 4, 4)]:
+        for shape in [(24, 2), (24, 3, 4), (24, 4, 4, 4), (24, 8, 4, 4), (24, 5, 6, 4, 4)]:
             for axis in range(len(shape)):
                 for cudnn_off in [False, True]:
                     for output_mean_var in [False, True]:
