This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[master] Fix issue with even number of channels in BatchNorm (#20907)
* Fixed issue with batchnorm on even number of channels

* Formatted last commit
piotrwolinski-intel committed Mar 2, 2022
1 parent 188d7b6 commit 8069a18
Showing 3 changed files with 34 additions and 20 deletions.
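
The visible refactor in this commit is uniform across the three files: the runtime `bool fuse_relu` argument is removed from the DNNL batch-norm entry points and becomes a non-type template parameter, so the fused and unfused variants are distinct functions with the plain (attrs, ctx, inputs, req, outputs) signature that `DNNLRun` receives as its first argument. A minimal sketch of the shape of that refactor follows; the names are illustrative, not the MXNet API.

#include <cstdio>

// Before: the fusion flag is a trailing runtime argument, so the
// signature differs from the plain per-operator callback.
void ForwardOld(int x, bool fuse_relu) {
  std::printf("%s forward of %d\n", fuse_relu ? "fused" : "plain", x);
}

// After: the shared logic lives in one Impl that still takes the flag
// at runtime...
void ForwardImpl(int x, bool fuse_relu) {
  std::printf("%s forward of %d\n", fuse_relu ? "fused" : "plain", x);
}

// ...and a thin template bakes each flag value into a distinct
// function whose signature no longer mentions the flag.
template <bool fuse_relu>
void Forward(int x) {
  ForwardImpl(x, fuse_relu);
}

int main() {
  ForwardOld(0, true);             // old style: flag at the call site
  Forward</*fuse_relu*/ true>(1);  // new style: flag fixed at compile time
  Forward</*fuse_relu*/ false>(2);
}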
6 changes: 2 additions & 4 deletions src/operator/contrib/batch_norm_relu.cc
@@ -149,12 +149,11 @@ void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
   if (SupportDNNLBNReLU(inputs[0], param)) {
     CHECK_GT(outputs.size(), 3U);
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
     });
     return;
   }
@@ -167,11 +166,10 @@ void BatchNormWithReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                        const std::vector<OpReqType>& req,
                                        const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = true;
   if (SupportDNNLBNReLU(inputs[0], param)) {
     CHECK_EQ(inputs.size(), 9U);
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ true>, attrs, ctx, inputs, req, outputs);
     return;
   }
   LOG(FATAL) << "BatchNormWithReLU operator only supports oneDNN Backend.";
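
For context on the operator itself: batch norm with a fused ReLU computes y = max(0, gamma * (x - mean) / sqrt(var + eps) + beta) in a single pass instead of chaining two separate operators. A small self-contained sketch of the per-channel math (plain C++, illustrative only, not the oneDNN kernel):

#include <cmath>
#include <cstdio>
#include <vector>

// One channel of batch norm with the ReLU fused in:
//   y = max(0, gamma * (x - mean) / sqrt(var + eps) + beta)
std::vector<float> BatchNormRelu(const std::vector<float>& x,
                                 float gamma, float beta,
                                 float eps = 1e-5f) {
  float mean = 0.0f, var = 0.0f;
  for (float v : x) mean += v;
  mean /= x.size();
  for (float v : x) var += (v - mean) * (v - mean);
  var /= x.size();
  std::vector<float> y;
  y.reserve(x.size());
  for (float v : x) {
    float norm = gamma * (v - mean) / std::sqrt(var + eps) + beta;
    y.push_back(norm > 0.0f ? norm : 0.0f);  // the fused ReLU clamp
  }
  return y;
}

int main() {
  // Normalized values below zero are clamped by the fused activation.
  for (float v : BatchNormRelu({1.0f, 2.0f, 3.0f, 4.0f}, 1.0f, 0.0f))
    std::printf("%.3f ", v);
  std::printf("\n");
}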
6 changes: 2 additions & 4 deletions src/operator/nn/batch_norm.cc
@@ -481,11 +481,10 @@ void BatchNormComputeExCPU(const nnvm::NodeAttrs& attrs,
                            const std::vector<NDArray>& outputs) {
   CHECK_EQ(inputs.size(), 5U);
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
   if (SupportDNNLBN(inputs[0], param)) {
     DNNL_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     DNNL_REAL_TYPE_SWITCH(inputs[0].dtype(), DTYPE, {
-      DNNLBatchNormForward<DTYPE>(attrs, ctx, inputs, req, outputs, fuse_relu);
+      DNNLRun(DNNLBatchNormForward<DTYPE, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
     });
     DNNL_OPCHECK_RUN(BatchNormCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
@@ -499,10 +498,9 @@ void BatchNormGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
-  bool fuse_relu = false;
   if (SupportDNNLBN(inputs[0], param)) {
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLBatchNormBackward<float>(attrs, ctx, inputs, req, outputs, fuse_relu);
+    DNNLRun(DNNLBatchNormBackward<float, /*fuse_relu*/ false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(BatchNormGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
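
Worth noting at these call sites: DNNLBatchNormForward<DTYPE, true> and DNNLBatchNormForward<DTYPE, false> are now two separate functions with identical signatures, so what used to be a runtime flag reduces to a choice between function pointers. A reduced illustration of that property (hypothetical kernel, not MXNet code):

#include <cstdio>

// Each <DType, fuse_relu> pair instantiates a distinct function, all
// sharing the same plain signature. (Normalization elided; only the
// fused clamp is shown.)
template <typename DType, bool fuse_relu>
void BatchNormForward(const DType* x, DType* y, int n) {
  for (int i = 0; i < n; ++i)
    y[i] = (fuse_relu && x[i] < DType(0)) ? DType(0) : x[i];
}

using Kernel = void (*)(const float*, float*, int);

int main() {
  bool fuse = true;  // a runtime decision...
  // ...resolved once into a pointer to the matching instantiation.
  Kernel k = fuse ? &BatchNormForward<float, true>
                  : &BatchNormForward<float, false>;
  float x[2] = {-1.0f, 2.0f}, y[2];
  k(x, y, 2);
  std::printf("%.1f %.1f\n", y[0], y[1]);  // prints: 0.0 2.0
}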
42 changes: 30 additions & 12 deletions src/operator/nn/dnnl/dnnl_batch_norm-inl.h
@@ -145,12 +145,12 @@ static DNNLBNForward& GetBNForward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
-                          const OpContext& ctx,
-                          const std::vector<NDArray>& inputs,
-                          const std::vector<OpReqType>& req,
-                          const std::vector<NDArray>& outputs,
-                          bool fuse_relu) {
+void DNNLBatchNormForwardImpl(const nnvm::NodeAttrs& attrs,
+                              const OpContext& ctx,
+                              const std::vector<NDArray>& inputs,
+                              const std::vector<OpReqType>& req,
+                              const std::vector<NDArray>& outputs,
+                              bool fuse_relu) {
   const BatchNormParam& param = nnvm::get<BatchNormParam>(attrs.parsed);
   std::vector<NDArray> in_data(inputs.begin(), inputs.begin() + batchnorm::kInMovingMean);
 
@@ -261,6 +261,15 @@ void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
   }
 }
 
+template <typename DType, bool fuse_relu>
+void DNNLBatchNormForward(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<NDArray>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<NDArray>& outputs) {
+  DNNLBatchNormForwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+}
+
 class DNNLBNBackward {
   std::shared_ptr<dnnl::batch_normalization_backward> bwd;
   const std::shared_ptr<dnnl::memory> weight_m;
@@ -317,12 +326,12 @@ static DNNLBNBackward& GetBNBackward(const BatchNormParam& param,
 }
 
 template <typename DType>
-void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
-                           const OpContext& ctx,
-                           const std::vector<NDArray>& inputs,
-                           const std::vector<OpReqType>& req,
-                           const std::vector<NDArray>& outputs,
-                           bool fuse_relu) {
+void DNNLBatchNormBackwardImpl(const nnvm::NodeAttrs& attrs,
+                               const OpContext& ctx,
+                               const std::vector<NDArray>& inputs,
+                               const std::vector<OpReqType>& req,
+                               const std::vector<NDArray>& outputs,
+                               bool fuse_relu) {
   if (fuse_relu) {
     CHECK_EQ(inputs.size(), 9U);
   } else {
@@ -481,6 +490,15 @@ void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
     LOG(FATAL) << "oneDNN batch normalization backward: should not reach here ...";
   }
 }
+
+template <typename DType, bool fuse_relu>
+void DNNLBatchNormBackward(const nnvm::NodeAttrs& attrs,
+                           const OpContext& ctx,
+                           const std::vector<NDArray>& inputs,
+                           const std::vector<OpReqType>& req,
+                           const std::vector<NDArray>& outputs) {
+  DNNLBatchNormBackwardImpl<DType>(attrs, ctx, inputs, req, outputs, fuse_relu);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN
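Finally, each instantiation is handed to DNNLRun, which takes the kernel as its first argument followed by the usual operator arguments; DNNLRun's internals are not part of this diff. As a rough mental model only (a guess at the general runner pattern, not MXNet's actual DNNLRun), such a wrapper forwards its arguments and centralizes per-op bookkeeping:

#include <cstdio>
#include <utility>

// A minimal runner in the spirit of DNNLRun: takes the kernel as a
// function pointer, does shared pre/post work, forwards the rest.
// (Pure illustration; the real DNNLRun is defined elsewhere in MXNet.)
template <typename Fn, typename... Args>
void Run(Fn* fn, Args&&... args) {
  std::puts("common pre-op bookkeeping");
  fn(std::forward<Args>(args)...);
  std::puts("common post-op bookkeeping");
}

template <bool fuse_relu>
void BatchNormForward(int x) {
  std::printf("forward(%d), fused=%d\n", x, fuse_relu);
}

int main() {
  Run(BatchNormForward</*fuse_relu*/ true>, 42);
}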
