From 06c51e974cbaa70939e6ac82c98e5ee1b9a6bdf3 Mon Sep 17 00:00:00 2001 From: shufan wu Date: Thu, 18 Apr 2019 22:02:25 +0800 Subject: [PATCH] move the position of MKL_Compute --- src/operator/tensor/elemwise_unary_op.h | 43 ++++++++++++------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h index 82a9aa9fe7a1..279efcf97084 100644 --- a/src/operator/tensor/elemwise_unary_op.h +++ b/src/operator/tensor/elemwise_unary_op.h @@ -266,6 +266,27 @@ class UnaryOp : public OpBase { } #if MSHADOW_USE_MKL == 1 + template + static void MKL_Compute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) return; + auto type_flag = inputs[0].type_flag_; + size_t input_size = inputs[0].Size(); + if ((req[0] == kWriteTo || req[0] == kWriteInplace) && + mkl_func::check_size(input_size) && + mkl_func::check_type(type_flag)) { + // set DType as float or double according to type_flag + MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, { + MKL_OP::Vectorize(input_size, inputs[0].dptr(), outputs[0].dptr()); + }); + } else { + Compute(attrs, ctx, inputs, req, outputs); + } + } + template static void MKL_ComputeEx(const nnvm::NodeAttrs& attrs, const OpContext& ctx, @@ -375,28 +396,6 @@ class UnaryOp : public OpBase { } } -#if MSHADOW_USE_MKL == 1 - template - static void MKL_Compute(const nnvm::NodeAttrs& attrs, - const OpContext& ctx, - const std::vector& inputs, - const std::vector& req, - const std::vector& outputs) { - if (req[0] == kNullOp) return; - auto type_flag = inputs[0].type_flag_; - size_t input_size = inputs[0].Size(); - if ((req[0] == kWriteTo || req[0] == kWriteInplace) && - mkl_func::check_size(input_size) && - mkl_func::check_type(type_flag)) { - // set DType as float or double according to type_flag - MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, { - MKL_OP::Vectorize(input_size, inputs[0].dptr(), outputs[0].dptr()); - }); - } else { - Compute(attrs, ctx, inputs, req, outputs); - } - } -#endif // MSHADOW_USE_MKL == 1 }; /*! \brief Map legacy unary_bwd to backward_grad */