From 482d1906cf7564f1cca487e949716727852e8262 Mon Sep 17 00:00:00 2001
From: piotrw
Date: Tue, 22 Feb 2022 09:38:59 +0000
Subject: [PATCH 1/2] Fixed issue with batchnorm on even number of channels

---
 src/operator/contrib/batch_norm_relu.cc    |  6 ++----
 src/operator/nn/batch_norm.cc              |  6 ++----
 src/operator/nn/dnnl/dnnl_batch_norm-inl.h | 22 ++++++++++++++++++++--
 3 files changed, 24 insertions(+), 10 deletions(-)

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index e15bcbea1850..3d0cbb149317 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -149,12 +149,11 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
   if (SupportDNNLBNReLU(inputs[0], param)) {
     CHECK_GT(outputs.size(), 3U);
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      DNNLRun(DNNLBatchNormForward<DTYPE, true>, attrs, ctx, inputs, req, outputs);
     });
     return;
   }
@@ -167,11 +166,10 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
   if (SupportDNNLBNReLU(inputs[0], param)) {
     CHECK_EQ(inputs.size(), 9U);
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    DNNLRun(DNNLBatchNormBackward<float, true>, attrs, ctx, inputs, req, outputs);
     return;
   }
   LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index 04cc78a02c85..fa422f156c54 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -481,11 +481,10 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
                            const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
   if (SupportDNNLBN(inputs[0], param)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      DNNLRun(DNNLBatchNormForward<DTYPE, false>, attrs, ctx, inputs, req, outputs);
     });
     DNNL_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
@@ -499,10 +498,9 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
   if (SupportDNNLBN(inputs[0], param)) {
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    DNNLRun(DNNLBatchNormBackward<float, false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index ca644340a37f..cb5d13f91e3e 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -145,7 +145,7 @@ static DNNLBNForward& GetBNForward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
+void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
                           const OpContext& ctx,
                           const std::vector<NDArray>& inputs,
                           const std::vector<OpReqType>& req,
@@ -261,6 +261,15 @@ void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template <typename DType, bool fuse_relu>
+void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs) {
+  DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+}
+
 class DNNLBNBackward {
   std::shared_ptr<dnnl::batch_normalization_backward> bwd;
   const std::shared_ptr<dnnl::memory> weight_m;
@@ -317,7 +326,7 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
+void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx,
                            const std::vector<NDArray>& inputs,
                            const std::vector<OpReqType>& req,
@@ -481,6 +490,15 @@ void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
     LOG(FATAL) << "oneDNN batch normalization backward: should not reach here ...";
   }
 }
+
+template <typename DType, bool fuse_relu>
+void DNNLBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                           const std::vector<NDArray> &inputs,
+                           const std::vector<OpReqType> &req,
+                           const std::vector<NDArray> &outputs) {
+  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
+                                   fuse_relu);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN

From 0510abc93117b94ec63e6fb389c49b3d44d99d2e Mon Sep 17 00:00:00 2001
From: piotrw
Date: Tue, 22 Feb 2022 10:32:48 +0000
Subject: [PATCH 2/2] Formatted last commit

---
 src/operator/nn/dnnl/dnnl_batch_norm-inl.h | 40 ++++++++++++++++--------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
index cb5d13f91e3e..97f21aef686b 100644
--- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
+++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -146,11 +146,11 @@ static DNNLBNForward& GetBNForward(const BatchNormParam& param,
 
 template <typename DType>
 void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const std::vector<NDArray>& inputs,
-                          const std::vector<OpReqType>& req,
-                          const std::vector<NDArray>& outputs,
-                          bool fuse_relu) {
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs,
+                              bool fuse_relu) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> in_data(inputs.begin(),
                                inputs.begin() + batchnorm::kInMovingMean);
@@ -263,10 +263,10 @@ void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
 }
 
 template <typename DType, bool fuse_relu>
 void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
-                              const OpContext& ctx,
-                              const std::vector<NDArray>& inputs,
-                              const std::vector<OpReqType>& req,
-                              const std::vector<NDArray>& outputs) {
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
   DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
 }
@@ -327,11 +327,11 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
 
 template <typename DType>
 void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs,
-                           bool fuse_relu) {
+                               const OpContext& ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs,
+                               bool fuse_relu) {
   if (fuse_relu) {
     CHECK_EQ(inputs.size(), 9U);
   } else {
@@ -492,12 +492,12 @@ void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
 }
 
 template <typename DType, bool fuse_relu>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
-                           const std::vector<NDArray> &inputs,
-                           const std::vector<OpReqType> &req,
-                           const std::vector<NDArray> &outputs) {
-  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs,
-                                   fuse_relu);
+void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
 }
 }  // namespace op
 }  // namespace mxnet
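
Reviewer note (not part of the patches above): both commits hinge on turning the fuse_relu runtime argument into a template parameter. With the flag baked in at compile time, DNNLBatchNormForward and DNNLBatchNormBackward expose the plain (attrs, ctx, inputs, req, outputs) signature, so they can be handed to DNNLRun like the other oneDNN operators instead of being invoked directly. The sketch below is a standalone, simplified illustration of that wrapper pattern only; Run, BatchNorm and BatchNormImpl are hypothetical stand-ins, not MXNet APIs.

// wrapper_pattern_sketch.cc -- standalone illustration, not MXNet code.
#include <algorithm>
#include <cstdio>
#include <vector>

using Tensor = std::vector<float>;

// Fixed callback signature, analogous to the (attrs, ctx, inputs, req, outputs)
// signature a generic runner expects from an operator implementation.
using ComputeFn = void (*)(const std::vector<Tensor>&, std::vector<Tensor>*);

// Generic runner that only accepts ComputeFn; a real runner (like DNNLRun)
// would also handle layout conversion and fallback before dispatching.
void Run(ComputeFn fn, const std::vector<Tensor>& inputs, std::vector<Tensor>* outputs) {
  fn(inputs, outputs);
}

// Implementation keeps the flag as an ordinary runtime argument, mirroring
// DNNLBatchNormForwardImpl / DNNLBatchNormBackwardImpl.
void BatchNormImpl(const std::vector<Tensor>& inputs, std::vector<Tensor>* outputs,
                   bool fuse_relu) {
  *outputs = inputs;  // placeholder for the actual normalization
  if (fuse_relu) {
    for (Tensor& t : *outputs) {
      for (float& v : t) v = std::max(v, 0.0f);  // fused ReLU
    }
  }
}

// Thin wrapper: the flag becomes a template parameter, so every instantiation
// has exactly the ComputeFn signature and can be passed to Run.
template <bool fuse_relu>
void BatchNorm(const std::vector<Tensor>& inputs, std::vector<Tensor>* outputs) {
  BatchNormImpl(inputs, outputs, fuse_relu);
}

int main() {
  std::vector<Tensor> in{{-1.0f, 2.0f}};
  std::vector<Tensor> out;
  Run(BatchNorm<true>, in, &out);   // fused variant, as used from batch_norm_relu.cc
  std::printf("%g %g\n", out[0][0], out[0][1]);  // prints: 0 2
  Run(BatchNorm<false>, in, &out);  // plain variant, as used from batch_norm.cc
  std::printf("%g %g\n", out[0][0], out[0][1]);  // prints: -1 2
  return 0;
}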