From b0be6c50ae686a59c340c56b9c11493a5c84ff72 Mon Sep 17 00:00:00 2001
From: Shufan <33112206+juliusshufan@users.noreply.github.com>
Date: Thu, 23 May 2019 07:17:38 +0800
Subject: [PATCH] Integrate the MKL VML functions into MXNet to speed up
 element-wise math computation (#14893)

* mkl_func test with erf & log ops, builds successfully
* fix lint and build issues
* try to add support for sparse arrays
* fix build
* add functions
* fix review comments
* remove unnecessary code
* update test case
* minor fix
* move the position of MKL_Compute
* fix cpplint
* cpp lint
* trigger ci
* address comments
* coding style
* enable layernorm
* fix windows build
* revert changes to FComputeEx
* int -> index_t
* remove workspace
* fix lint
* clean code
---
 src/operator/mkl_functions-inl.h        | 165 ++++++++++++++++++
 src/operator/nn/layer_norm-inl.h        |   9 +-
 src/operator/nn/layer_norm.cc           |  58 +++++-
 src/operator/tensor/elemwise_unary_op.h | 118 ++++++++-----
 .../tensor/elemwise_unary_op_basic.cc   |  50 +++++-
 tests/python/gpu/test_operator_gpu.py   |  48 +++++
 6 files changed, 398 insertions(+), 50 deletions(-)
 create mode 100644 src/operator/mkl_functions-inl.h

diff --git a/src/operator/mkl_functions-inl.h b/src/operator/mkl_functions-inl.h
new file mode 100644
index 000000000000..608034732e0e
--- /dev/null
+++ b/src/operator/mkl_functions-inl.h
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file mkl_functions-inl.h
+ * \brief Wrapper for MKL VML functions
+ * \author Tao Lv, Shufan Wu
+*/
+#ifndef MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
+#define MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
+
+#if MSHADOW_USE_MKL == 1
+#include "mkl_vml.h"
+
+namespace mxnet {
+namespace op {
+namespace mkl_func {
+
+MSHADOW_XINLINE
+static bool check_size(const size_t n) {
+  const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX;
+  return (n <= MKL_INT_MAX);
+}
+
+MSHADOW_XINLINE
+static bool check_type(const int t) {
+  return (t == mshadow::kFloat32 || t == mshadow::kFloat64);
+}
+
+#define MXNET_MKL_UNARY_MATH_FUNC(name, func)                                               \
+struct name {                                                                               \
+  MSHADOW_XINLINE static void Vectorize(const index_t n, const float *src, float *dst) {    \
+    vs##func(static_cast<MKL_INT>(n), src, dst);                                            \
+  }                                                                                         \
+  MSHADOW_XINLINE static void Vectorize(const index_t n, const double *src, double *dst) {  \
+    vd##func(static_cast<MKL_INT>(n), src, dst);                                            \
+  }                                                                                         \
+};
+
+#define MXNET_MKL_BINARY_MATH_FUNC(name, func)              \
+struct name {                                               \
+  MSHADOW_XINLINE static void Vectorize(const index_t n,    \
+                                        const float *a,     \
+                                        const float *b,     \
+                                        float *c) {         \
+    vs##func(static_cast<MKL_INT>(n), a, b, c);             \
+  }                                                         \
+  MSHADOW_XINLINE static void Vectorize(const index_t n,    \
+                                        const double *a,    \
+                                        const double *b,    \
+                                        double *c) {        \
+    vd##func(static_cast<MKL_INT>(n), a, b, c);             \
+  }                                                         \
+};
+
+MXNET_MKL_UNARY_MATH_FUNC(erf, Erf);
+MXNET_MKL_UNARY_MATH_FUNC(exp, Exp);
+MXNET_MKL_UNARY_MATH_FUNC(exp2, Exp2);
+MXNET_MKL_UNARY_MATH_FUNC(exp10, Exp10);
+MXNET_MKL_UNARY_MATH_FUNC(expm1, Expm1);
+MXNET_MKL_UNARY_MATH_FUNC(log, Ln);
+MXNET_MKL_UNARY_MATH_FUNC(log2, Log2);
+MXNET_MKL_UNARY_MATH_FUNC(log10, Log10);
+MXNET_MKL_UNARY_MATH_FUNC(log1p, Log1p);
+
+MXNET_MKL_UNARY_MATH_FUNC(sin, Sin);
+MXNET_MKL_UNARY_MATH_FUNC(cos, Cos);
+MXNET_MKL_UNARY_MATH_FUNC(tan, Tan);
+MXNET_MKL_UNARY_MATH_FUNC(asin, Asin);
+MXNET_MKL_UNARY_MATH_FUNC(acos, Acos);
+MXNET_MKL_UNARY_MATH_FUNC(atan, Atan);
+
+MXNET_MKL_UNARY_MATH_FUNC(sinh, Sinh);
+MXNET_MKL_UNARY_MATH_FUNC(cosh, Cosh);
+MXNET_MKL_UNARY_MATH_FUNC(tanh, Tanh);
+MXNET_MKL_UNARY_MATH_FUNC(asinh, Asinh);
+MXNET_MKL_UNARY_MATH_FUNC(acosh, Acosh);
+MXNET_MKL_UNARY_MATH_FUNC(atanh, Atanh);
+
+MXNET_MKL_UNARY_MATH_FUNC(sqrt, Sqrt);
+MXNET_MKL_UNARY_MATH_FUNC(abs, Abs);
+MXNET_MKL_UNARY_MATH_FUNC(cbrt, Cbrt);
+MXNET_MKL_UNARY_MATH_FUNC(round, Round);
+MXNET_MKL_UNARY_MATH_FUNC(ceil, Ceil);
+MXNET_MKL_UNARY_MATH_FUNC(floor, Floor);
+MXNET_MKL_UNARY_MATH_FUNC(trunc, Trunc);
+
+MXNET_MKL_UNARY_MATH_FUNC(lgamma, LGamma);
+MXNET_MKL_UNARY_MATH_FUNC(tgamma, TGamma);
+MXNET_MKL_UNARY_MATH_FUNC(square, Sqr);
+
+MXNET_MKL_BINARY_MATH_FUNC(add, Add);
+MXNET_MKL_BINARY_MATH_FUNC(sub, Sub);
+MXNET_MKL_BINARY_MATH_FUNC(mul, Mul);
+MXNET_MKL_BINARY_MATH_FUNC(pow, Pow);
+MXNET_MKL_BINARY_MATH_FUNC(hypot, Hypot);
+
+template <typename DType>
+MSHADOW_XINLINE static void sum_(index_t n, DType *in, DType *dst) {
+  DType sum = 0.0f;
+  for (index_t i = 0; i < n; i++)
+    sum += in[i];
+
+  dst[0] = sum;
+}
+
+// LayerNorm on the last dimension
+template <typename DType>
+MSHADOW_XINLINE static void LayerNormLastDim(index_t m,
+                                             index_t n,
+                                             DType *a,
+                                             DType *b,
+                                             DType *gamma,
+                                             DType *beta,
+                                             DType *mean,
+                                             DType *var,
+                                             DType eps) {
+  auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+#pragma omp parallel for num_threads(nthreads)
+  for (index_t i = 0; i < m; i++) {
+    DType* in_offset = a + i * n;
+    DType* out_offset = b + i * n;
+
+    sum_(n, in_offset, &(mean[i]));
+    mean[i] /= n;
+    var[i] = 0.0f;
+#if !defined(_MSC_VER)
+#pragma omp simd
+#endif
+    for (index_t j = 0; j < n; j++) {
+      out_offset[j] = in_offset[j] - mean[i];
+      var[i] += out_offset[j] * out_offset[j];
+    }
+    var[i] = math::sqrt(var[i] / n + eps);
+#if !defined(_MSC_VER)
+#pragma omp simd
+#endif
+    for (index_t j = 0; j < n; j++) {
+      out_offset[j] = out_offset[j] * gamma[j] / var[i] + beta[j];
+    }
+  }
+}
+
+}  // namespace mkl_func
+}  // namespace op
+}  // namespace mxnet
+#endif  // MSHADOW_USE_MKL == 1
+#endif  // MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
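For illustration only, not part of the patch: a minimal standalone sketch of what MXNET_MKL_UNARY_MATH_FUNC(log, Ln) expands to and how the generated wrapper is invoked. It assumes an MKL-enabled toolchain (link against mkl_rt or equivalent); the struct name log_vml and the use of long long in place of index_t are local to this sketch.

    // demo_vml_wrapper.cc, build e.g. with: g++ demo_vml_wrapper.cc -lmkl_rt
    #include <cstdio>
    #include "mkl_vml.h"

    struct log_vml {  // what MXNET_MKL_UNARY_MATH_FUNC(log, Ln) generates
      static void Vectorize(const long long n, const float *src, float *dst) {
        vsLn(static_cast<MKL_INT>(n), src, dst);   // single-precision natural log
      }
      static void Vectorize(const long long n, const double *src, double *dst) {
        vdLn(static_cast<MKL_INT>(n), src, dst);   // double-precision natural log
      }
    };

    int main() {
      float src[4] = {1.f, 2.f, 4.f, 8.f};
      float dst[4];
      log_vml::Vectorize(4, src, dst);   // one VML call covers the whole buffer
      for (int i = 0; i < 4; ++i)
        std::printf("ln(%g) = %g\n", src[i], dst[i]);
      return 0;
    }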
diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index 456a5cb805ec..7636c9bb8715 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -63,6 +63,10 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
   }
 };
 
+static int GetRealAxis(int axis, int ndim) {
+  return axis < 0 ? (axis + ndim) : axis;
+}
+
 template<typename xpu>
 void LayerNormCompute(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx, const std::vector<TBlob>& inputs,
@@ -79,10 +83,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
   const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
   if (req[0] == kNullOp) return;
   CHECK_NE(req[0], kAddTo);
-  int axis = param.axis;
-  if (axis < 0) {
-    axis += static_cast<int>(inputs[0].ndim());
-  }
+  int axis = GetRealAxis(param.axis, inputs[0].ndim());
   CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
   CHECK_EQ(inputs.size(), 3U);
   Stream<xpu> *s = ctx.get_stream<xpu>();
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index 5b0aca6910f7..1581f1acb050 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -27,6 +27,10 @@
 #include <nnvm/op_attr_types.h>
 #include "../elemwise_op_common.h"
 
+#if MSHADOW_USE_MKL == 1
+#include "../mkl_functions-inl.h"
+#endif
+
 namespace mxnet {
 namespace op {
 
@@ -39,10 +43,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
-  int axis = param.axis;
-  if (axis < 0) {
-    axis += dshape.ndim();
-  }
+  int axis = GetRealAxis(param.axis, dshape.ndim());
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;
 
@@ -64,7 +65,6 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-
 template<>
 void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
                            const OpContext& ctx, const std::vector<TBlob>& inputs,
@@ -73,6 +73,50 @@ void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
   return LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
 }
 
+#if MSHADOW_USE_MKL == 1
+void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo);
+  CHECK_EQ(inputs.size(), 3U);
+  int axis = GetRealAxis(param.axis, inputs[0].ndim());
+
+  if (axis == (inputs[layernorm::kData].ndim() - 1) &&
+      (inputs[0].type_flag_ == kFloat32 || inputs[0].type_flag_ == kFloat64)) {
+    // Compute the data needed for the reduce operation.
+    mxnet::TShape red_src_shape, red_dst_shape;
+    BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_, outputs[layernorm::kMean].shape_,
+                                &red_src_shape, &red_dst_shape);
+    const TBlob in_data = inputs[layernorm::kData].reshape(red_src_shape);
+    const TBlob mean_data = outputs[layernorm::kMean].reshape(red_dst_shape);
+    const TBlob std_data = outputs[layernorm::kStd].reshape(red_dst_shape);
+    const int outer_size = red_dst_shape.Size();
+    const int channel_size = red_src_shape.Size() / red_dst_shape.Size();
+
+    // call
+    MSHADOW_SGL_DBL_TYPE_SWITCH(in_data.type_flag_, DType, {
+      mkl_func::LayerNormLastDim(outer_size, channel_size,
+                                 in_data.dptr<DType>(),
+                                 outputs[layernorm::kOut].dptr<DType>(),
+                                 inputs[layernorm::kGamma].dptr<DType>(),
+                                 inputs[layernorm::kBeta].dptr<DType>(),
+                                 outputs[layernorm::kMean].dptr<DType>(),
+                                 outputs[layernorm::kStd].dptr<DType>(),
+                                 static_cast<DType>(param.eps));
+    });
+  } else {
+    // fallback
+    LayerNormCompute<cpu>(attrs, ctx, inputs, req, outputs);
+  }
+}
+#endif
+
 template<>
 void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx, const std::vector<TBlob>& inputs,
@@ -126,7 +170,11 @@ axis to be the last item in the input shape.
 })
 .set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
+#if MSHADOW_USE_MKL == 1
+.set_attr<FCompute>("FCompute<cpu>", LayerNormComputeMKL)
+#else
 .set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
+#endif
 .set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
                                            const std::vector<nnvm::NodeEntry>& ograds) {
   std::vector<nnvm::NodeEntry> heads;
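The MKL path above normalizes each row of an (outer_size x channel_size) view in two passes: center against the row mean, then scale by the inverse standard deviation and apply gamma/beta. A single-threaded, MKL-free sketch of the same math follows (illustrative only; the helper name layer_norm_last_dim is not from the patch):

    #include <cmath>
    #include <cstdio>

    // Same two passes as mkl_func::LayerNormLastDim, without OpenMP or MKL.
    void layer_norm_last_dim(int m, int n, const float *a, float *b,
                             const float *gamma, const float *beta,
                             float *mean, float *var, float eps) {
      for (int i = 0; i < m; ++i) {
        const float *in = a + i * n;
        float *out = b + i * n;
        float sum = 0.f;
        for (int j = 0; j < n; ++j) sum += in[j];
        mean[i] = sum / n;
        var[i] = 0.f;
        for (int j = 0; j < n; ++j) {
          out[j] = in[j] - mean[i];            // pass 1: center
          var[i] += out[j] * out[j];
        }
        var[i] = std::sqrt(var[i] / n + eps);  // biased std, as in the kernel
        for (int j = 0; j < n; ++j)
          out[j] = out[j] * gamma[j] / var[i] + beta[j];  // pass 2: scale & shift
      }
    }

    int main() {
      const float x[6] = {1.f, 2.f, 3.f, 4.f, 5.f, 6.f};
      const float gamma[3] = {1.f, 1.f, 1.f}, beta[3] = {0.f, 0.f, 0.f};
      float y[6], mean[2], var[2];
      layer_norm_last_dim(2, 3, x, y, gamma, beta, mean, var, 1e-5f);
      for (int i = 0; i < 6; ++i)
        std::printf("%f\n", y[i]);   // each row comes out as roughly {-1.22, 0, 1.22}
      return 0;
    }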
diff --git a/src/operator/tensor/elemwise_unary_op.h b/src/operator/tensor/elemwise_unary_op.h
index 86e8b0192dd8..458106e02671 100644
--- a/src/operator/tensor/elemwise_unary_op.h
+++ b/src/operator/tensor/elemwise_unary_op.h
@@ -35,9 +35,10 @@
 #include "../mxnet_op.h"
 #include "../elemwise_op_common.h"
 #include "../../ndarray/ndarray_function.h"
+
 #if MSHADOW_USE_MKL == 1
-#include "mkl.h"
-#endif
+#include "../mkl_functions-inl.h"
+#endif  // MSHADOW_USE_MKL == 1
 
 namespace mxnet {
 namespace op {
@@ -264,6 +265,48 @@ class UnaryOp : public OpBase {
     }
   }
 
+#if MSHADOW_USE_MKL == 1
+  template<typename OP, typename MKL_OP>
+  static void MKL_Compute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+    if (req[0] == kNullOp) return;
+    auto type_flag = inputs[0].type_flag_;
+    size_t input_size = inputs[0].Size();
+    if ((req[0] == kWriteTo || req[0] == kWriteInplace) &&
+        mkl_func::check_size(input_size) &&
+        mkl_func::check_type(type_flag)) {
+      // set DType as float or double according to type_flag
+      MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
+        MKL_OP::Vectorize(input_size, inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
+      });
+    } else {
+      Compute<cpu, OP>(attrs, ctx, inputs, req, outputs);
+    }
+  }
+
+  template<typename OP, typename MKL_OP>
+  static void MKL_ComputeEx(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+    CHECK_EQ(inputs.size(), 1U)
+      << "Invalid input, only one input is allowed";
+    CHECK_EQ(outputs.size(), 1U)
+      << "Invalid output, only one output is allowed";
+    CHECK_NE(inputs[0].storage_type(), kDefaultStorage)
+      << "Operation requires a sparse input storage type";
+    CHECK_NE(outputs[0].storage_type(), kDefaultStorage)
+      << "Operation requires a sparse output storage type";
+    if (inputs[0].storage_shape().Size()) {
+      MapToFCompute<cpu>(attrs, ctx, inputs, req, outputs, MKL_Compute<OP, MKL_OP>);
+    }
+  }
+#endif
+
   template<typename xpu, typename OP>
   static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                const OpContext &ctx,
@@ -352,43 +395,6 @@ class UnaryOp : public OpBase {
       LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
     }
   }
-
-#if MSHADOW_USE_MKL == 1
-  static inline void MKLLog(MKL_INT size, const float* pIn, float* pOut) {
-    vsLn(size, pIn, pOut);
-  }
-
-  static inline void MKLLog(MKL_INT size, const double* pIn, double* pOut) {
-    vdLn(size, pIn, pOut);
-  }
-#endif
-
-  template<typename xpu, typename OP>
-  static void LogCompute(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
-    if (req[0] == kNullOp) return;
-    // if built with MSHADOW_USE_MKL, call the MKL log when req is kWriteTo, type_flag
-    // is mshadow::kFloat32 or mshadow::kFloat64, and data size is at most MKL_INT_MAX
-#if MSHADOW_USE_MKL == 1
-    auto type_flag = inputs[0].type_flag_;
-    const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX;
-    size_t input_size = inputs[0].Size();
-    if (req[0] == kWriteTo &&
-        input_size <= MKL_INT_MAX &&
-        (type_flag == mshadow::kFloat32 || type_flag == mshadow::kFloat64)) {
-      MSHADOW_SGL_DBL_TYPE_SWITCH(type_flag, DType, {
-        MKLLog(input_size, inputs[0].dptr<DType>(), outputs[0].dptr<DType>());
-      });
-    } else {
-      Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
-    }
-#else
-    Compute<xpu, OP>(attrs, ctx, inputs, req, outputs);
-#endif
-  }
 };
 
 /*! \brief Map legacy unary_bwd to backward_grad */
@@ -557,7 +563,7 @@ struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
   NNVM_REGISTER_OP(__name$)                                         \
   .set_num_inputs(1)                                                \
   .set_num_outputs(1)                                               \
-  .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>) \
+  .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)   \
   .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)     \
   .set_attr<nnvm::FInplaceOption>("FInplaceOption",                 \
     [](const NodeAttrs& attrs){                                     \
@@ -565,6 +571,38 @@ struct ReshapeLikeParam : public dmlc::Parameter<ReshapeLikeParam> {
     })                                                              \
   .add_argument("data", "NDArray-or-Symbol", "The input array.")
 
+#if MSHADOW_USE_MKL == 1
+  /*! \brief MKL unary compute.
+   *  When MXNet is built with MKL, this macro registers FCompute with
+   *  UnaryOp::MKL_Compute() so that the math function is completed by MKL VML.
+   */
+  #define MXNET_MKL_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(__name$, __xpu$,                       \
+                                                         __kernel$, __mkl_kernel$)              \
+  MXNET_OPERATOR_REGISTER_UNARY(__name$)                                                        \
+  MXNET_ADD_SPARSE_OP_ALIAS(__name$)                                                            \
+  .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1,                   \
+                                                                        false, true, true>)     \
+  .set_attr<FCompute>("FCompute<" #__xpu$ ">", UnaryOp::MKL_Compute<__kernel$, __mkl_kernel$>)  \
+  .set_attr<FComputeEx>("FComputeEx<" #__xpu$ ">", UnaryOp::ComputeEx<__xpu$, __kernel$>)
+
+  /*! \brief MKL unary compute, row_sparse only.
+   *  Same as above, except that FComputeEx is also routed through
+   *  UnaryOp::MKL_ComputeEx().
+   */
+  #define MXNET_MKL_OPERATOR_REGISTER_UNARY_WITH_RSP(__name$, __xpu$, __kernel$, __mkl_kernel$) \
+  MXNET_OPERATOR_REGISTER_UNARY(__name$)                                                        \
+  MXNET_ADD_SPARSE_OP_ALIAS(__name$)                                                            \
+  .set_attr<FInferStorageType>("FInferStorageType", ElemwiseStorageType<1, 1,                   \
+                                                                        false, true, false>)    \
+  .set_attr<FCompute>("FCompute<" #__xpu$ ">", UnaryOp::MKL_Compute<__kernel$, __mkl_kernel$>)  \
+  .set_attr<FComputeEx>("FComputeEx<" #__xpu$ ">", UnaryOp::MKL_ComputeEx<__kernel$, __mkl_kernel$>)
+
+  #define MXNET_MKL_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(__name$, __xpu$, __kernel$,          \
+                                                           __mkl_kernel$)                       \
+  MXNET_OPERATOR_REGISTER_UNARY(__name$)                                                        \
+  .set_attr<FCompute>("FCompute<" #__xpu$ ">", UnaryOp::MKL_Compute<__kernel$, __mkl_kernel$>)
+#endif
+
 /*! \brief Unary compute, with FComputeEx for csr and rsp available */
 #define MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(__name$, __xpu$, __kernel$) \
 MXNET_OPERATOR_REGISTER_UNARY(__name$)                                         \
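The fast path in UnaryOp::MKL_Compute above is gated by three conditions. A standalone sketch of that predicate (the enums and the function use_mkl_path are simplified stand-ins, not MXNet types):

    #include <climits>
    #include <cstddef>
    #include <cstdio>

    enum Req  { kNullOp, kWriteTo, kWriteInplace, kAddTo };
    enum Type { kFloat32, kFloat64, kFloat16, kInt32 };

    // VML runs only for a plain write, a length that fits in MKL_INT, and
    // fp32/fp64 data; everything else falls back to the generic mshadow kernel.
    bool use_mkl_path(Req req, std::size_t n, Type t) {
      const std::size_t max_n = INT_MAX;          // mkl_func::check_size (LP64 MKL)
      return (req == kWriteTo || req == kWriteInplace) &&
             n <= max_n &&                         // size fits in MKL_INT
             (t == kFloat32 || t == kFloat64);     // mkl_func::check_type
    }

    int main() {
      std::printf("%d\n", use_mkl_path(kWriteTo, 1024, kFloat32));  // 1: VML path
      std::printf("%d\n", use_mkl_path(kAddTo, 1024, kFloat32));    // 0: accumulation
      std::printf("%d\n", use_mkl_path(kWriteTo, 1024, kFloat16));  // 0: unsupported dtype
      return 0;
    }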
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cc b/src/operator/tensor/elemwise_unary_op_basic.cc
index 1634606ac0ac..77225065d928 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cc
+++ b/src/operator/tensor/elemwise_unary_op_basic.cc
@@ -829,6 +829,26 @@ The storage type of ``fix`` output depends upon the input storage type:
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes);
 
 // square
+#if MSHADOW_USE_MKL == 1
+MXNET_MKL_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(square, cpu, mshadow_op::square, mkl_func::square)
+.describe(R"code(Returns element-wise squared value of the input.
+
+.. math::
+   square(x) = x^2
+
+Example::
+
+   square([2, 3, 4]) = [4, 9, 16]
+
+The storage type of ``square`` output depends upon the input storage type:
+
+   - square(default) = default
+   - square(row_sparse) = row_sparse
+   - square(csr) = csr
+
+)code" ADD_FILELINE)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_square"});
+#else
 MXNET_OPERATOR_REGISTER_UNARY_WITH_RSP_CSR(square, cpu, mshadow_op::square)
 .describe(R"code(Returns element-wise squared value of the input.
 
@@ -847,6 +867,7 @@ The storage type of ``square`` output depends upon the input storage type:
 
 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_square"});
+#endif
 
 MXNET_OPERATOR_REGISTER_BINARY_WITH_SPARSE_CPU(_backward_square,
                                                unary_bwd<mshadow_op::square_grad>);
 
@@ -926,9 +947,14 @@ Example::
 
    erf([0, -1., 10.]) = [0., -0.8427, 1.]
 
 )code" ADD_FILELINE)
+#if MSHADOW_USE_MKL == 1
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::MKL_Compute<mshadow_op::erf, mkl_func::erf>)
+#else
 .set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::erf>)
+#endif  // MSHADOW_USE_MKL == 1
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_erf"});
 
+
 MXNET_OPERATOR_REGISTER_BINARY(_backward_erf)
 .set_attr<FCompute>("FCompute<cpu>",
                     ElemwiseBinaryOp::Compute<cpu, unary_bwd<mshadow_op::erf_grad>>);
 
@@ -970,6 +996,23 @@ MXNET_OPERATOR_REGISTER_BINARY(_backward_rcbrt)
                                unary_bwd<mshadow_op::reciprocal_cube_root_grad>>);
 
 // exp
+#if MSHADOW_USE_MKL == 1
+MXNET_MKL_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(exp, cpu, mshadow_op::exp, mkl_func::exp)
+MXNET_ADD_SPARSE_OP_ALIAS(exp)
+.describe(R"code(Returns element-wise exponential value of the input.
+
+.. math::
+   exp(x) = e^x \approx 2.718^x
+
+Example::
+
+   exp([0, 1, 2]) = [1., 2.71828175, 7.38905621]
+
+The storage type of ``exp`` output is always dense
+
+)code" ADD_FILELINE)
+.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_mul"});
+#else
 MXNET_OPERATOR_REGISTER_UNARY_WITH_SPARSE_DR(exp, cpu, mshadow_op::exp)
 MXNET_ADD_SPARSE_OP_ALIAS(exp)
 .describe(R"code(Returns element-wise exponential value of the input.
@@ -985,6 +1028,7 @@ The storage type of ``exp`` output is always dense
 
 )code" ADD_FILELINE)
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseOut{"_mul"});
+#endif
 
 // log
 MXNET_OPERATOR_REGISTER_UNARY(log)
 .describe(R"code(Returns element-wise Natural logarithmic value of the input.
@@ -996,7 +1040,11 @@ The natural logarithm is logarithm in base *e*, so that ``log(exp(x)) = x``
 
 The storage type of ``log`` output is always dense
 
 )code" ADD_FILELINE)
-.set_attr<FCompute>("FCompute<cpu>", UnaryOp::LogCompute<cpu, mshadow_op::log>)
+#if MSHADOW_USE_MKL == 1
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::MKL_Compute<mshadow_op::log, mkl_func::log>)
+#else
+.set_attr<FCompute>("FCompute<cpu>", UnaryOp::Compute<cpu, mshadow_op::log>)
+#endif  // MSHADOW_USE_MKL == 1
 .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseIn{"_backward_log"});
 
 // log10
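A quick way to sanity-check the VML kernel these registrations dispatch to, assuming an MKL build: compare vsErf against std::erf (standalone sketch, not part of the patch or the test suite):

    #include <cmath>
    #include <cstdio>
    #include "mkl_vml.h"

    int main() {
      const MKL_INT n = 5;
      float x[5] = {-2.f, -1.f, 0.f, 1.f, 2.f};
      float y[5];
      vsErf(n, x, y);   // the same call mkl_func::erf::Vectorize issues
      for (int i = 0; i < 5; ++i)
        std::printf("x=%5.2f  vsErf=%.6f  std::erf=%.6f\n",
                    x[i], y[i], std::erf(x[i]));
      return 0;
    }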
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 2a1583ed639e..710686da9e7c 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -2201,6 +2201,54 @@ def test_context_num_gpus():
     # Test that num_gpus reports at least one GPU, as the test is run on a GPU host.
     assert mx.context.num_gpus() > 0
 
+def math_log(shape, dtype, check_value):
+    np_x = np.random.rand(*tuple(shape))
+    x = mx.nd.array(np_x, dtype=dtype)
+    y = mx.nd.log(data=x)
+    if check_value:
+        x_ = x.as_in_context(mx.cpu())
+        y_ = mx.nd.log(data=x_)
+        assert_almost_equal(y.asnumpy(), y_.asnumpy())
+
+def math_erf(shape, dtype, check_value):
+    np_x = np.random.rand(*tuple(shape))
+    x = mx.nd.array(np_x, dtype=dtype)
+    y = mx.nd.erf(data=x)
+    if check_value:
+        x_ = x.as_in_context(mx.cpu())
+        y_ = mx.nd.erf(data=x_)
+        assert_almost_equal(y.asnumpy(), y_.asnumpy())
+
+def math_square(shape, dtype, check_value):
+    np_x = np.random.rand(*tuple(shape))
+    x = mx.nd.array(np_x, dtype=dtype)
+    y = mx.nd.square(data=x)
+    if check_value:
+        x_ = x.as_in_context(mx.cpu())
+        y_ = mx.nd.square(data=x_)
+        assert_almost_equal(y.asnumpy(), y_.asnumpy())
+
+def run_math(op, shape, dtype="float32", check_value=True):
+    run_num = 10
+    for _ in range(run_num):
+        if op == 'log':
+            math_log(shape=shape, dtype=dtype, check_value=check_value)
+        elif op == 'erf':
+            math_erf(shape=shape, dtype=dtype, check_value=check_value)
+        elif op == 'square':
+            math_square(shape=shape, dtype=dtype, check_value=check_value)
+
+@with_seed()
+def test_math():
+    ops = ['log', 'erf', 'square']
+    check_value = True
+    shape_lst = [[1000], [100, 1000], [10, 100, 100], [10, 100, 100, 100]]
+    dtypes = ["float32", "float64"]
+    for shape in shape_lst:
+        for dtype in dtypes:
+            for op in ops:
+                run_math(op, shape, dtype, check_value=check_value)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()