diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc index e15bcbea1850..3d0cbb149317 100644 --- a/src/operator/contrib/batch_norm_relu.cc +++ b/src/operator/contrib/batch_norm_relu.cc @@ -149,12 +149,11 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 5U); const BatchNormParam& param = nnvm::get(attrs.parsed); - bool fuse_relu = true; if (SupportDNNLBNReLU(inputs[0], param)) { CHECK_GT(outputs.size(), 3U); DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs); DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, { - DNNLBatchNormForward(attrs, ctx, inputs, req, outputs, fuse_relu); + DNNLRun(DNNLBatchNormForward, attrs, ctx, inputs, req, outputs); }); return; } @@ -167,11 +166,10 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const BatchNormParam& param = nnvm::get(attrs.parsed); - bool fuse_relu = true; if (SupportDNNLBNReLU(inputs[0], param)) { CHECK_EQ(inputs.size(), 9U); DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - DNNLBatchNormBackward(attrs, ctx, inputs, req, outputs, fuse_relu); + DNNLRun(DNNLBatchNormBackward, attrs, ctx, inputs, req, outputs); return; } LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend."; diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index 04cc78a02c85..fa422f156c54 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -481,11 +481,10 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& outputs) { CHECK_EQ(inputs.size(), 5U); const BatchNormParam& param = nnvm::get(attrs.parsed); - bool fuse_relu = false; if (SupportDNNLBN(inputs[0], param)) { DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs); DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, { - DNNLBatchNormForward(attrs, ctx, inputs, req, outputs, fuse_relu); + DNNLRun(DNNLBatchNormForward, attrs, ctx, inputs, req, outputs); }); DNNL_OPCHECK_RUN(BatchNormCompute, attrs, ctx, inputs, req, outputs); return; @@ -499,10 +498,9 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const BatchNormParam& param = nnvm::get(attrs.parsed); - bool fuse_relu = false; if (SupportDNNLBN(inputs[0], param)) { DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - DNNLBatchNormBackward(attrs, ctx, inputs, req, outputs, fuse_relu); + DNNLRun(DNNLBatchNormBackward, attrs, ctx, inputs, req, outputs); DNNL_OPCHECK_RUN(BatchNormGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h index ca644340a37f..97f21aef686b 100644 --- a/src/operator/nn/dnnl/dnnl_batch_norm-inl.h +++ b/src/operator/nn/dnnl/dnnl_batch_norm-inl.h @@ -145,12 +145,12 @@ static DNNLBNForward& GetBNForward(const BatchNormParam& param, } template -void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - bool fuse_relu) { +void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + bool fuse_relu) { const BatchNormParam& param = nnvm::get(attrs.parsed); std::vector in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean); @@ -261,6 +261,15 @@ void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs, } } +template +void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + DNNLBatchNormForwardImpl(attrs, ctx, inputs, req, outputs, fuse_relu); +} + class DNNLBNBackward { std::shared_ptr bwd; const std::shared_ptr weight_m; @@ -317,12 +326,12 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param, } template -void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs, - bool fuse_relu) { +void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs, + bool fuse_relu) { if (fuse_relu) { CHECK_EQ(inputs.size(), 9U); } else { @@ -481,6 +490,15 @@ void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs, LOG(FATAL) << "oneDNN batch normalization backward: should not reach here ..."; } } + +template +void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + DNNLBatchNormBackwardImpl(attrs, ctx, inputs, req, outputs, fuse_relu); +} } // namespace op } // namespace mxnet #endif // MXNET_USE_ONEDNN