diff --git a/src/executor/attach_op_execs_pass.cc b/src/executor/attach_op_execs_pass.cc index c011c1d9ce03..0e415ef5112a 100644 --- a/src/executor/attach_op_execs_pass.cc +++ b/src/executor/attach_op_execs_pass.cc @@ -159,6 +159,9 @@ class StatefulComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(out_array, req); + CreateDefaultInputs(in_array, &in_array_fallback); + fcompute_(state_, op_ctx, in_array_fallback, req, out_array); + return; #endif fcompute_(state_, op_ctx, in_array, req, out_array); } @@ -226,6 +229,13 @@ class FComputeExExecutor : public OpExecutor { op_ctx.run_ctx = rctx; #if MXNET_USE_MKLDNN == 1 InvalidateOutputs(out_array, req); + // TODO(alex): (MXNET-847) Remove this fallback feature after subgraph implemented + const auto is_mkldnn = Op::GetAttr("TIsMKLDNN"); + if (!is_mkldnn.get(attrs_.op, false)) { + CreateDefaultInputs(in_array, &in_array_fallback); + fcompute_(attrs_, op_ctx, in_array_fallback, req, out_array); + return; + } #endif fcompute_(attrs_, op_ctx, in_array, req, out_array); } diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h index cd1db0ac1944..52f7c790c77e 100644 --- a/src/executor/exec_pass.h +++ b/src/executor/exec_pass.h @@ -86,6 +86,10 @@ class OpExecutor { virtual OpStatePtr state() const { return OpStatePtr(); } + + // TODO(alexzai): (MXNET-856) Remove instance member after subgraph feature added + protected: + std::vector in_array_fallback; }; /*! diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc index b8c2045fba12..ba44ebd4ed4d 100644 --- a/src/operator/nn/activation.cc +++ b/src/operator/nn/activation.cc @@ -155,6 +155,7 @@ The following activation functions are supported: }) .set_attr("FCompute", ActivationCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ActivationComputeExCPU) #endif .set_attr("FGradient", ActivationGrad{"_backward_Activation"}) @@ -184,6 +185,7 @@ NNVM_REGISTER_OP(_backward_Activation) #endif .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ActivationGradComputeExCPU) #endif .set_attr("FCompute", ActivationGradCompute); diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc index b15f84e107e0..4ea494d64e47 100644 --- a/src/operator/nn/batch_norm.cc +++ b/src/operator/nn/batch_norm.cc @@ -601,6 +601,7 @@ the sparse tensors will fallback. #endif .set_attr("FGradient", BatchNormGrad) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -633,6 +634,7 @@ NNVM_REGISTER_OP(_backward_BatchNorm) #endif .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", BatchNormGradComputeExCPU) #endif .set_attr("FCompute", BatchNormGradCompute); diff --git a/src/operator/nn/concat.cc b/src/operator/nn/concat.cc index 9df459e9224d..ac8a814ce70f 100644 --- a/src/operator/nn/concat.cc +++ b/src/operator/nn/concat.cc @@ -367,6 +367,7 @@ Example:: .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsMKLDNN", true) #endif CONCAT_FORWARD_ATTRS .set_attr("FInferShape", ConcatShape) @@ -387,6 +388,7 @@ NNVM_REGISTER_OP(_backward_Concat) .set_attr("TIsBackward", true) .set_attr("FInferStorageType", BackwardConcatStorageType) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConcatGradComputeExCPU) #endif .set_attr("FCompute", ConcatGradCompute); diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 8f25cf0dcbb1..d5abe629123b 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -484,6 +484,7 @@ There are other options to tune the performance. #endif .set_attr("FCompute", ConvolutionCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConvolutionComputeExCPU) #endif .set_attr("FGradient", ConvolutionGrad{"_backward_Convolution"}) @@ -509,6 +510,7 @@ NNVM_REGISTER_OP(_backward_Convolution) }) .set_attr_parser(ConvolutionParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", ConvolutionGradComputeExCPU) #endif .set_attr("FCompute", ConvolutionGradCompute); diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc index a4be1a0c56a0..1ab391d92b04 100644 --- a/src/operator/nn/deconvolution.cc +++ b/src/operator/nn/deconvolution.cc @@ -413,6 +413,7 @@ NNVM_REGISTER_OP(Deconvolution) }) .set_attr("FCompute", DeconvolutionCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", DeconvolutionComputeExCPU) #endif .set_attr("FGradient", DeconvolutionGrad{"_backward_Deconvolution"}) @@ -436,6 +437,7 @@ NNVM_REGISTER_OP(_backward_Deconvolution) }) .set_attr_parser(DeconvolutionParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", DeconvolutionGradComputeExCPU) #endif .set_attr("FCompute", DeconvolutionGradCompute); diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index eb881d29abd1..d8a32f0ae963 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -290,6 +290,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored. return std::vector{"output"}; }) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) @@ -322,6 +323,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected) .set_attr("FInferStorageType", BackwardFCStorageType) .set_attr_parser(ParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", FullyConnectedGradComputeExCPU) #endif .set_attr("FCompute", FullyConnectedGradCompute); diff --git a/src/operator/nn/lrn.cc b/src/operator/nn/lrn.cc index 587cf930920e..a428eb1e4faf 100644 --- a/src/operator/nn/lrn.cc +++ b/src/operator/nn/lrn.cc @@ -180,6 +180,7 @@ number of kernels in the layer. }) .set_attr("FCompute", LRNCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", LRNComputeExCPU) #endif .set_attr("FGradient", LRNGrad{"_backward_LRN"}) @@ -194,6 +195,7 @@ NNVM_REGISTER_OP(_backward_LRN) #endif .set_attr("TIsBackward", true) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", LRNGradComputeExCPU) // Native compute requires norm while MKLDNN does not so cannot be compared in debug mode .set_attr("TExcludeMKLDNNDebug", true) diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h index 273afcd32dc7..6eb90f845d37 100644 --- a/src/operator/nn/mkldnn/mkldnn_base-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h @@ -356,6 +356,18 @@ static inline void InvalidateOutputs(const std::vector &arrs, } } +// TODO(alexzai): (MXNET-856) Remove helper function after subgraph feature added +static inline void CreateDefaultInputs(const std::vector &arrs, + std::vector *out_arrs) { + out_arrs->clear(); + for (size_t i = 0; i < arrs.size(); ++i) { + if (arrs[i].IsMKLDNNData()) + out_arrs->push_back(arrs[i].Reorder2Default()); + else + out_arrs->push_back(arrs[i]); + } +} + const mkldnn::memory *GetWeights(const NDArray &arr, const mkldnn::memory::primitive_desc &target_pd, int num_groups); diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 2d118142bc79..c133b63623af 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -395,6 +395,7 @@ For each window ``X``, the mathematical expression for Lp pooling is: .set_attr("FInferShape", PoolingShape) .set_attr("FCompute", PoolingCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", PoolingComputeExCPU) #endif .set_attr("FGradient", @@ -424,6 +425,7 @@ NNVM_REGISTER_OP(_backward_Pooling) #endif .set_attr_parser(PoolingParamParser) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", PoolingGradComputeExCPU) #endif .set_attr("FCompute", PoolingGradCompute); diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc index 88b7b5fc473e..81e775cac526 100644 --- a/src/operator/nn/softmax.cc +++ b/src/operator/nn/softmax.cc @@ -98,6 +98,7 @@ Example:: }) .set_attr("FCompute", SoftmaxCompute) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", SoftmaxComputeExCPU) .set_attr("FInferStorageType", SoftmaxStorageType) #endif diff --git a/src/operator/tensor/elemwise_sum.cc b/src/operator/tensor/elemwise_sum.cc index 9630988165ce..1666537e2860 100644 --- a/src/operator/tensor/elemwise_sum.cc +++ b/src/operator/tensor/elemwise_sum.cc @@ -179,6 +179,9 @@ The storage type of ``add_n`` output depends on storage types of inputs [](const NodeAttrs& attrs) { return std::vector{ResourceRequest::kTempSpace}; }) +#if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) +#endif .set_attr("FInferShape", ElementWiseSumShape) .set_attr("FInferType", ElementWiseSumType) .set_attr("FInferStorageType", ElementWiseSumForwardInferStorageType) diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index e09a6cccddbf..eb070a411279 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -299,7 +299,11 @@ class UnaryOp : public OpBase { } break; case kWriteInplace: +// cannot check if ptrs are the same for MKLDNN because we may have +// created copies of input when reordering. WriteInPlace will still write to original array +#if MXNET_USE_MKLDNN == 0 CHECK_EQ(inputs[0].dptr_, outputs[0].dptr_); +#endif break; case kNullOp: break; diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc index f7f21f9076a6..c3e9c2dc91d0 100644 --- a/src/operator/tensor/elemwise_unary_op_basic.cc +++ b/src/operator/tensor/elemwise_unary_op_basic.cc @@ -206,6 +206,7 @@ MXNET_OPERATOR_REGISTER_UNARY(_copy) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) +.set_attr("TIsMKLDNN", true) #endif .set_attr("FInplaceIdentity", [](const NodeAttrs& attrs){ @@ -225,6 +226,7 @@ NNVM_REGISTER_OP(_backward_copy) .set_attr("FCompute", UnaryOp::IdentityCompute) .set_attr("FComputeEx", CopyEx) #if MXNET_USE_MKLDNN == 1 +.set_attr("TIsMKLDNN", true) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py index ba4cf3f0116a..e597d0f5fc58 100644 --- a/tests/python/mkl/test_mkldnn.py +++ b/tests/python/mkl/test_mkldnn.py @@ -381,6 +381,50 @@ def check_fullyconnected_training(stype): for stype in stypes: check_fullyconnected_training(stype) +@with_seed() +def test_non_mkldnn_fcomputeex(): + # test special case where MKLDNN formatted NDArray feeds into non-mkldnn fcomputeex operator + # conv is example where MKLDNN NDArray is created from regular NDArrays + # CustomOps is example of non-mkldnn fcomputeex operator + + @mx.operator.register("custom") + class CustomProp(mx.operator.CustomOpProp): + def __int__(self): + super(CustomProp, self).__init__(need_top_grad=False) + + def list_arguments(self): + return ['data'] + + def list_outputs(self): + return ['output'] + + def infer_shape(self, in_shape): + data_shape = in_shape[0] + output_shape = in_shape[0] + return [data_shape], [output_shape], [] + + def infer_type(self, in_type): + dtype = in_type[0] + return [dtype], [dtype], [] + + def create_operator(self, ctx, shapes, dtypes): + return Custom() + + + class Custom(mx.operator.CustomOp): + def forward(self, is_train, req, in_data, out_data, aux): + print(in_data[0]) + self.assign(out_data[0], req[0], in_data[0]) + + def backward(self, req, out_grad, in_data, out_data, in_grad, aux): + self.assign(in_grad[0], req[0], out_grad) + + data = mx.symbol.Variable('data') + conv = mx.sym.Convolution(data=data, kernel=(5, 5), pad=(1, 1), stride=(1,1), num_filter=8, name="conv", no_bias=True) + custom = mx.symbol.Custom(name='custom', data=conv, op_type='custom') + exec1 = custom.bind(mx.cpu(), args={'data': mx.nd.ones([10,3,96,96]), 'conv_weight': mx.nd.ones([8,3,5,5])}) + exec1.forward()[0].wait_to_read() + if __name__ == '__main__': install.test_mkldnn_install()