From beb8505b6d72ae9d05917b67c2f95851aa1bf391 Mon Sep 17 00:00:00 2001
From: Da zheng <zhengda1936@gmail.com>
Date: Tue, 7 Nov 2017 01:35:26 +0000
Subject: [PATCH] Handle kAddTo in MKLDNN operators.

---
 src/operator/nn/mkldnn/mkldnn_base-inl.h         | 38 ++++++++++++
 src/operator/nn/mkldnn/mkldnn_convolution.cc     | 55 ++++++-----------
 src/operator/nn/mkldnn/mkldnn_deconvolution.cc   | 61 ++++++------------
 src/operator/nn/mkldnn/mkldnn_fully_connected.cc | 61 ++++++------------
 src/operator/nn/mkldnn/mkldnn_relu-inl.h         |  8 +--
 src/operator/nn/mkldnn/mkldnn_sum.cc             | 52 ++++++++++++++++
 6 files changed, 149 insertions(+), 126 deletions(-)
 create mode 100644 src/operator/nn/mkldnn/mkldnn_sum.cc

diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 733980ef54e8..6d6671c181a4 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -151,6 +151,44 @@ inline static mkldnn_mem_ptr CreateMKLDNNMem(const mkldnn::memory::primitive_des
   return ret;
 }
 
+enum OutDataOp {
+  Noop,
+  CopyBack,
+  AddBack,
+};
+
+typedef std::pair<OutDataOp, mkldnn_mem_ptr> mkldnn_output_t;
+
+static inline mkldnn_output_t CreateMKLDNNMem(const NDArray &arr,
+    const mkldnn::memory::primitive_desc &desc, OpReqType req) {
+  if (kAddTo == req)
+    return mkldnn_output_t(OutDataOp::AddBack, CreateMKLDNNMem(desc));
+  else {
+    mkldnn_mem_ptr mem = const_cast<NDArray &>(arr).CreateMKLDNNData(desc);
+    if (mem == nullptr)
+      return mkldnn_output_t(OutDataOp::CopyBack, CreateMKLDNNMem(desc));
+    else
+      return mkldnn_output_t(OutDataOp::Noop, mem);
+  }
+}
+
+namespace op {
+void Sum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
+         const mkldnn::memory &out);
+}
+
+static inline void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
+  if (res.first == CopyBack)
+    const_cast<NDArray &>(arr).CopyFrom(*res.second);
+  else if (res.first == AddBack) {
+    // TODO I might need to reorder.
+ mkldnn_mem_const_ptr mem = arr.GetMKLDNNData(res.second->get_primitive_desc()); + mkldnn_mem_ptr out = CreateMKLDNNMem(res.second->get_primitive_desc()); + op::Sum(*res.second, *mem, *out); + const_cast(arr).CopyFrom(*out); + } +} + inline static mkldnn_mem_const_ptr GetWeights(const NDArray &arr, const mkldnn::memory::primitive_desc &target_pd, int num_groups) { mkldnn_mem_const_ptr mem; diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index 28ee1874d6d8..61134d0d8021 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -182,18 +182,18 @@ void MKLDNNConvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ct auto engine = CpuEngine::Instance().get_engine(); auto weight_mem = GetWeights(in_data[conv::kWeight], fwd_pd.weights_primitive_desc(), param.num_group); - - auto out_mem = const_cast(out_data[conv::kOut]).CreateMKLDNNData( - fwd_pd.dst_primitive_desc()); + auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], + fwd_pd.dst_primitive_desc(), req[conv::kOut]); if (param.no_bias) { MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(fwd_pd, - *data_mem, *weight_mem, *out_mem)); + *data_mem, *weight_mem, *out_mem.second)); } else { auto bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd_pd.bias_primitive_desc()); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(fwd_pd, - *data_mem, *weight_mem, *bias_mem, *out_mem)); + *data_mem, *weight_mem, *bias_mem, *out_mem.second)); } + CommitOutput(out_data[conv::kOut], out_mem); MKLDNNStream::Instance().Submit(); } @@ -216,17 +216,11 @@ void MKLDNNConvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext &c bwdData_pd.diff_dst_primitive_desc()); auto weight_mem = GetWeights(inputs[conv::kWeight + 1], bwdData_pd.weights_primitive_desc(), param.num_group); - auto in_grad_mem = const_cast(in_grad[conv::kData]).CreateMKLDNNData( - 
bwdData_pd.diff_src_primitive_desc()); - bool copy_back = false; - if (in_grad_mem == nullptr) { - in_grad_mem = CreateMKLDNNMem(bwdData_pd.diff_src_primitive_desc()); - copy_back = true; - } + auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData], + bwdData_pd.diff_src_primitive_desc(), req[conv::kData]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem)); - if (copy_back) - const_cast(in_grad[conv::kData]).CopyFrom(*in_grad_mem); + *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[conv::kData], in_grad_mem); } if (req[conv::kWeight]) { mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd @@ -236,32 +230,21 @@ void MKLDNNConvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext &c bwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder( bwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = const_cast(in_grad[conv::kWeight]).CreateMKLDNNData( - bwdWeights_pd.diff_weights_primitive_desc()); - bool copy_back_weight = false; - bool copy_back_bias = false; - if (in_grad_weight == nullptr) { - in_grad_weight = CreateMKLDNNMem(bwdWeights_pd.diff_weights_primitive_desc()); - copy_back_weight = true; - } - mkldnn_mem_const_ptr in_grad_bias; + auto in_grad_weight = CreateMKLDNNMem(in_grad[conv::kWeight], + bwdWeights_pd.diff_weights_primitive_desc(), req[conv::kWeight]); + mkldnn_output_t in_grad_bias; if (param.no_bias) { MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight)); + bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); } else { - in_grad_bias = const_cast(in_grad[conv::kBias]).CreateMKLDNNData( - bwdWeights_pd.diff_bias_primitive_desc()); - if (in_grad_bias == nullptr) { - in_grad_bias = CreateMKLDNNMem(bwdWeights_pd.diff_bias_primitive_desc()); - copy_back_bias = true; - } + in_grad_bias 
= CreateMKLDNNMem(in_grad[conv::kBias], + bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight, *in_grad_bias)); + bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, + *in_grad_bias.second)); } - if (copy_back_weight) - const_cast(in_grad[conv::kWeight]).CopyFrom(*in_grad_weight); - if (copy_back_bias) - const_cast(in_grad[conv::kBias]).CopyFrom(*in_grad_bias); + CommitOutput(in_grad[conv::kWeight], in_grad_weight); + CommitOutput(in_grad[conv::kBias], in_grad_bias); } MKLDNNStream::Instance().Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index f8675b637f62..8a8566432706 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -179,18 +179,12 @@ void MKLDNNDeconvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext & deconvFwd_pd.diff_src_primitive_desc()); auto weight_mem = GetWeights(in_data[deconv::kWeight], deconvFwd_pd.weights_primitive_desc(), param.num_group); - auto out_mem = const_cast(out_data[deconv::kOut]).CreateMKLDNNData( - deconvFwd_pd.diff_dst_primitive_desc()); - bool copy_back = false; - if (out_mem == nullptr) { - out_mem = CreateMKLDNNMem(deconvFwd_pd.diff_dst_primitive_desc()); - copy_back = true; - } + auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut], + deconvFwd_pd.diff_dst_primitive_desc(), req[deconv::kOut]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data( - deconvFwd_pd, *data_mem, *weight_mem, *out_mem)); - if (copy_back) - const_cast(out_data[deconv::kOut]).CopyFrom(*out_mem); + deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second)); + CommitOutput(out_data[deconv::kOut], out_mem); MKLDNNStream::Instance().Submit(); if (!param.no_bias) { // add bias, broadcast bias to dim 1: channel @@ -209,7 +203,6 @@ void 
MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext const std::vector& outputs) { const std::vector &in_grad = outputs; const DeconvolutionParam& param = nnvm::get(attrs.parsed); - CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace"; mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData( param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], nullptr, @@ -219,17 +212,11 @@ void MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext bwdData_pd.src_primitive_desc()); auto weight_mem = GetWeights(inputs[deconv::kWeight + 1], bwdData_pd.weights_primitive_desc(), param.num_group); - auto in_grad_mem = const_cast(in_grad[deconv::kData]).CreateMKLDNNData( - bwdData_pd.dst_primitive_desc()); - bool copy_back = false; - if (in_grad_mem == nullptr) { - in_grad_mem = CreateMKLDNNMem(bwdData_pd.dst_primitive_desc()); - copy_back = true; - } + auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData], + bwdData_pd.dst_primitive_desc(), req[deconv::kData]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(bwdData_pd, - *out_grad_mem, *weight_mem, *in_grad_mem)); - if (copy_back) - const_cast(in_grad[deconv::kData]).CopyFrom(*in_grad_mem); + *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[deconv::kData], in_grad_mem); } if (req[deconv::kWeight]) { mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd @@ -237,37 +224,25 @@ void MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext inputs[deconv::kWeight + 1], param.no_bias ? 
nullptr : &inputs[deconv::kWeight + 1], inputs[deconv::kOut], bwdData_pd); - CHECK_NE(req[deconv::kWeight], kAddTo); auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder( bwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder( bwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = const_cast(in_grad[deconv::kWeight]).CreateMKLDNNData( - bwdWeights_pd.diff_weights_primitive_desc()); - bool copy_back_weight = false; - bool copy_back_bias = false; - if (in_grad_weight == nullptr) { - in_grad_weight = CreateMKLDNNMem(bwdWeights_pd.diff_weights_primitive_desc()); - copy_back_weight = true; - } - mkldnn_mem_const_ptr in_grad_bias; + auto in_grad_weight = CreateMKLDNNMem(in_grad[deconv::kWeight], + bwdWeights_pd.diff_weights_primitive_desc(), req[deconv::kWeight]); + mkldnn_output_t in_grad_bias; if (param.no_bias) { MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight)); + bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second)); } else { - in_grad_bias = const_cast(in_grad[deconv::kBias]).CreateMKLDNNData( - bwdWeights_pd.diff_bias_primitive_desc()); - if (in_grad_bias == nullptr) { - in_grad_bias = CreateMKLDNNMem(bwdWeights_pd.diff_bias_primitive_desc()); - copy_back_bias = true; - } + in_grad_bias = CreateMKLDNNMem(in_grad[deconv::kBias], + bwdWeights_pd.diff_bias_primitive_desc(), req[deconv::kBias]); MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights( - bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight, *in_grad_bias)); + bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second, + *in_grad_bias.second)); } - if (copy_back_weight) - const_cast(in_grad[deconv::kWeight]).CopyFrom(*in_grad_weight); - if (copy_back_bias) - const_cast(in_grad[deconv::kBias]).CopyFrom(*in_grad_bias); + CommitOutput(in_grad[deconv::kWeight], in_grad_weight); + CommitOutput(in_grad[deconv::kBias], 
in_grad_bias); } MKLDNNStream::Instance().Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 6e73fd50f95d..ae80dd8f9095 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -93,23 +93,17 @@ void MKLDNNFC_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); auto weight_mem = in_data[fullc::kWeight].GetMKLDNNDataReorder( ipFwd_pd.weights_primitive_desc()); - auto out_mem = const_cast(out_data[fullc::kOut]).CreateMKLDNNData( - ipFwd_pd.dst_primitive_desc()); - bool copy_back = false; - if (out_mem == nullptr) { - out_mem = CreateMKLDNNMem(ipFwd_pd.dst_primitive_desc()); - copy_back = true; - } + auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], + ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); if (param.no_bias) { MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_forward( - ipFwd_pd, *data_mem, *weight_mem, *out_mem)); + ipFwd_pd, *data_mem, *weight_mem, *out_mem.second)); } else { auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc()); MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd, - *data_mem, *weight_mem, *bias_mem, *out_mem)); + *data_mem, *weight_mem, *bias_mem, *out_mem.second)); } - if (copy_back) - const_cast(out_data[fullc::kOut]).CopyFrom(*out_mem); + CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Instance().Submit(); } @@ -131,17 +125,11 @@ void MKLDNNFC_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, ipBwdData_pd.diff_dst_primitive_desc()); auto weight_mem = inputs[fullc::kWeight + 1].GetMKLDNNDataReorder( ipBwdData_pd.weights_primitive_desc()); - auto in_grad_mem = const_cast(in_grad[fullc::kData]).CreateMKLDNNData( - ipBwdData_pd.diff_src_primitive_desc()); - bool copy_back = false; - if (in_grad_mem == 
nullptr) { - in_grad_mem = CreateMKLDNNMem(ipBwdData_pd.diff_src_primitive_desc()); - copy_back = true; - } + auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData], + ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]); MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_data( - ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem)); - if (copy_back) - const_cast(in_grad[fullc::kData]).CopyFrom(*in_grad_mem); + ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second)); + CommitOutput(in_grad[fullc::kData], in_grad_mem); } if (req[fullc::kWeight]) { mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd @@ -152,32 +140,21 @@ void MKLDNNFC_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, ipBwdWeights_pd.diff_dst_primitive_desc()); auto data_mem = inputs[fullc::kData + 1].GetMKLDNNDataReorder( ipBwdWeights_pd.src_primitive_desc()); - auto in_grad_weight = const_cast(in_grad[fullc::kWeight]).CreateMKLDNNData( - ipBwdWeights_pd.diff_weights_primitive_desc()); - bool copy_back_weight = false; - bool copy_back_bias = false; - if (in_grad_weight == nullptr) { - in_grad_weight = CreateMKLDNNMem(ipBwdWeights_pd.diff_weights_primitive_desc()); - copy_back_weight = true; - } - mkldnn_mem_const_ptr in_grad_bias; + auto in_grad_weight = CreateMKLDNNMem(in_grad[fullc::kWeight], + ipBwdWeights_pd.diff_weights_primitive_desc(), req[fullc::kWeight]); + mkldnn_output_t in_grad_bias; if (param.no_bias) { MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight)); + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second)); } else { - in_grad_bias = const_cast(in_grad[fullc::kBias]).CreateMKLDNNData( - ipBwdWeights_pd.diff_bias_primitive_desc()); - if (in_grad_bias == nullptr) { - in_grad_bias = CreateMKLDNNMem(ipBwdWeights_pd.diff_bias_primitive_desc()); - copy_back_bias = true; - } + in_grad_bias = 
CreateMKLDNNMem(in_grad[fullc::kBias], + ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]); MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_weights( - ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight, *in_grad_bias)); + ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second, + *in_grad_bias.second)); } - if (copy_back_weight) - const_cast(in_grad[fullc::kWeight]).CopyFrom(*in_grad_weight); - if (copy_back_bias) - const_cast(in_grad[fullc::kBias]).CopyFrom(*in_grad_bias); + CommitOutput(in_grad[fullc::kWeight], in_grad_weight); + CommitOutput(in_grad[fullc::kBias], in_grad_bias); } MKLDNNStream::Instance().Submit(); } diff --git a/src/operator/nn/mkldnn/mkldnn_relu-inl.h b/src/operator/nn/mkldnn/mkldnn_relu-inl.h index affb29ed7750..25ad61a5d68c 100644 --- a/src/operator/nn/mkldnn/mkldnn_relu-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_relu-inl.h @@ -76,9 +76,7 @@ void MKLDNNRelu_Backward(const OpContext &ctx, const NDArray &out_grad, return; } - // TODO we need to handle req std::shared_ptr diff_dst_memory = out_grad.GetMKLDNNData(); - // TODO shouldn't it be out_data? 
   std::shared_ptr<const mkldnn::memory> input_mem = in_data.GetMKLDNNData();
   mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
   mkldnn::memory::desc data_md = data_mpd.desc();
@@ -92,11 +90,11 @@ void MKLDNNRelu_Backward(const OpContext &ctx, const NDArray &out_grad,
   mkldnn::eltwise_backward::desc bw_desc(mkldnn::eltwise_relu, diff_md, data_md, alpha);
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, fw_pdesc);
 
-  std::shared_ptr<const mkldnn::memory> diff_src_memory
-      = const_cast<NDArray &>(in_grad).CreateMKLDNNData(bw_pdesc.diff_src_primitive_desc());
+  auto diff_src_memory = CreateMKLDNNMem(in_grad, bw_pdesc.diff_src_primitive_desc(), req);
   MKLDNNStream &stream = MKLDNNStream::Instance();
   stream.RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
-        *diff_dst_memory, *diff_src_memory));
+        *diff_dst_memory, *diff_src_memory.second));
+  CommitOutput(in_grad, diff_src_memory);
   stream.Submit();
 }
 
diff --git a/src/operator/nn/mkldnn/mkldnn_sum.cc b/src/operator/nn/mkldnn/mkldnn_sum.cc
new file mode 100644
index 000000000000..61ec1bbc4199
--- /dev/null
+++ b/src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_sum.cc
+ * \brief
+ * \author Da Zheng
+*/
+#include <iostream>
+
+#include "./mkldnn_ops-inl.h"
+#include "./mkldnn_base-inl.h"
+
+#if MXNET_USE_MKLDNN == 1
+namespace mxnet {
+namespace op {
+
+void Sum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
+         const mkldnn::memory &out) {
+  std::vector<mkldnn::memory::primitive_desc> input_pds(2);
+  std::vector<float> scales(2);
+  std::vector<mkldnn::primitive::at> inputs;
+  input_pds[0] = arr1.get_primitive_desc();
+  input_pds[1] = arr2.get_primitive_desc();
+  CHECK(input_pds[0] == input_pds[1]);
+  scales[0] = 1;
+  scales[1] = 1;
+  inputs.push_back(arr1);
+  inputs.push_back(arr2);
+  mkldnn::sum::primitive_desc sum_pd(scales, input_pds);
+  MKLDNNStream::Instance().RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
+}
+
+}
+}
+#endif