Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-v1.0] Add MKL-DNN BN (#16199)
Browse files Browse the repository at this point in the history
* add mkldnn bn

* add static_cast to transform data type

* change mkldnn_args_map_t

* retrigger CI
  • Loading branch information
rongzha1 authored and TaoLv committed Sep 23, 2019
1 parent f930baa commit 0b8805a
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 165 deletions.
14 changes: 7 additions & 7 deletions src/operator/nn/batch_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_batch_norm-inl.h"
#endif

Expand Down Expand Up @@ -379,7 +379,7 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
return true;
}

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
static inline bool SupportMKLDNNBN(const NDArray &input, const BatchNormParam &param) {
mxnet::TShape shape = input.shape();
return SupportMKLDNN(input) && shape.ndim() == 4
Expand Down Expand Up @@ -454,7 +454,7 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);

bool dispatched = false;
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
if (!dispatched) {
dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode,
in_attrs, out_attrs);
Expand Down Expand Up @@ -592,11 +592,11 @@ then set ``gamma`` to 1 and its gradient to 0.
.set_attr<nnvm::FInferType>("FInferType", BatchNormType)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
.set_attr<FCompute>("FCompute<cpu>", BatchNormCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormComputeExCPU)
#endif
.set_attr<nnvm::FGradient>("FGradient", BatchNormGrad)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
Expand All @@ -623,13 +623,13 @@ NNVM_REGISTER_OP(_backward_BatchNorm)
.set_num_outputs(3)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BatchNormStorageType)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr_parser(ParamParser<BatchNormParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", BatchNormGradComputeExCPU)
#endif
Expand Down
Loading

0 comments on commit 0b8805a

Please sign in to comment.