Commit

Integrating the MKL VML functions into MXNet to speed up (element-wise) mathematical computation (#14893)

* mkl_func test with erf & log ops; build succeeds

* fix lint and build issues

* Try to add support to sparse array

* fix build

* add functions

* Fix review comments

* remove unnecessary code

* Update test case

* minor fix

* move the position of MKL_Compute

* fix cpplint

* cpp lint

* trigger CI

* address comments

* coding style

* enable layernorm

* fix windows build

* revert changes to FComputeEx

* int -> index_t

* remove workspace

* fix lint

* clean code
juliusshufan authored and pengzhao-intel committed May 22, 2019
1 parent 37f5315 commit b0be6c5
Showing 6 changed files with 398 additions and 50 deletions.
165 changes: 165 additions & 0 deletions src/operator/mkl_functions-inl.h
@@ -0,0 +1,165 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file mkl_functions-inl.h
* \brief Wrapper for MKL VML functions
* \author Tao Lv, Shufan Wu
*/
#ifndef MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
#define MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_

#if MSHADOW_USE_MKL == 1
#include <climits>
#include "mkl_vml.h"

namespace mxnet {
namespace op {
namespace mkl_func {

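// MKL_INT is 32-bit in an LP64 build and 64-bit in an ILP64 build; reject
// array lengths that would overflow it before the casts below.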
MSHADOW_XINLINE
static bool check_size(const size_t n) {
const size_t MKL_INT_MAX = (sizeof(MKL_INT) == sizeof(int)) ? INT_MAX : LLONG_MAX;
return (n <= MKL_INT_MAX);
}

MSHADOW_XINLINE
static bool check_type(const int t) {
return (t == mshadow::kFloat32 || t == mshadow::kFloat64);
}

#define MXNET_MKL_UNARY_MATH_FUNC(name, func) \
struct name { \
MSHADOW_XINLINE static void Vectorize(const index_t n, const float *src, float *dst) { \
vs##func(static_cast<MKL_INT>(n), src, dst); \
} \
MSHADOW_XINLINE static void Vectorize(const index_t n, const double *src, double *dst) { \
vd##func(static_cast<MKL_INT>(n), src, dst); \
} \
};

#define MXNET_MKL_BINARY_MATH_FUNC(name, func) \
struct name { \
MSHADOW_XINLINE static void Vectorize(const index_t n, \
const float *a, \
const float *b, \
float *c) { \
vs##func(static_cast<MKL_INT>(n), a, b, c); \
} \
MSHADOW_XINLINE static void Vectorize(const index_t n, \
const double *a, \
const double *b, \
double *c) { \
vd##func(static_cast<MKL_INT>(n), a, b, c); \
} \
};

MXNET_MKL_UNARY_MATH_FUNC(erf, Erf);
MXNET_MKL_UNARY_MATH_FUNC(exp, Exp);
MXNET_MKL_UNARY_MATH_FUNC(exp2, Exp2);
MXNET_MKL_UNARY_MATH_FUNC(exp10, Exp10);
MXNET_MKL_UNARY_MATH_FUNC(expm1, Expm1);
MXNET_MKL_UNARY_MATH_FUNC(log, Ln);
MXNET_MKL_UNARY_MATH_FUNC(log2, Log2);
MXNET_MKL_UNARY_MATH_FUNC(log10, Log10);
MXNET_MKL_UNARY_MATH_FUNC(log1p, Log1p);

MXNET_MKL_UNARY_MATH_FUNC(sin, Sin);
MXNET_MKL_UNARY_MATH_FUNC(cos, Cos);
MXNET_MKL_UNARY_MATH_FUNC(tan, Tan);
MXNET_MKL_UNARY_MATH_FUNC(asin, Asin);
MXNET_MKL_UNARY_MATH_FUNC(acos, Acos);
MXNET_MKL_UNARY_MATH_FUNC(atan, Atan);

MXNET_MKL_UNARY_MATH_FUNC(sinh, Sinh);
MXNET_MKL_UNARY_MATH_FUNC(cosh, Cosh);
MXNET_MKL_UNARY_MATH_FUNC(tanh, Tanh);
MXNET_MKL_UNARY_MATH_FUNC(asinh, Asinh);
MXNET_MKL_UNARY_MATH_FUNC(acosh, Acosh);
MXNET_MKL_UNARY_MATH_FUNC(atanh, Atanh);

MXNET_MKL_UNARY_MATH_FUNC(sqrt, Sqrt);
MXNET_MKL_UNARY_MATH_FUNC(abs, Abs);
MXNET_MKL_UNARY_MATH_FUNC(cbrt, Cbrt);
MXNET_MKL_UNARY_MATH_FUNC(round, Round);
MXNET_MKL_UNARY_MATH_FUNC(ceil, Ceil);
MXNET_MKL_UNARY_MATH_FUNC(floor, Floor);
MXNET_MKL_UNARY_MATH_FUNC(trunc, Trunc);

MXNET_MKL_UNARY_MATH_FUNC(lgamma, LGamma);
MXNET_MKL_UNARY_MATH_FUNC(tgamma, TGamma);
MXNET_MKL_UNARY_MATH_FUNC(square, Sqr);

MXNET_MKL_BINARY_MATH_FUNC(add, Add);
MXNET_MKL_BINARY_MATH_FUNC(sub, Sub);
MXNET_MKL_BINARY_MATH_FUNC(mul, Mul);
MXNET_MKL_BINARY_MATH_FUNC(pow, Pow);
MXNET_MKL_BINARY_MATH_FUNC(hypot, Hypot);
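For illustration, a minimal standalone sketch of the VML calls these macros wrap (assumes an MKL build; the buffer names and values are hypothetical):

#include <cstdio>
#include <vector>
#include "mkl_vml.h"

int main() {
  std::vector<float> x = {1.0f, 2.0f, 4.0f};
  std::vector<float> y = {1.0f, 1.0f, 1.0f};
  std::vector<float> out(x.size());
  // What mkl_func::log::Vectorize(n, src, dst) dispatches to for float:
  vsLn(static_cast<MKL_INT>(x.size()), x.data(), out.data());
  // What mkl_func::add::Vectorize(n, a, b, c) dispatches to for float:
  vsAdd(static_cast<MKL_INT>(x.size()), x.data(), y.data(), out.data());
  std::printf("%f\n", out[0]);  // 2.000000 (1.0 + 1.0)
  return 0;
}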

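// Sequential reduction; used below to accumulate each row's mean.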
template <typename DType>
MSHADOW_XINLINE static void sum_(index_t n, DType *in, DType *dst) {
DType sum = 0.0f;
for (index_t i = 0; i < n; i++)
sum += in[i];

dst[0] = sum;
}

// LayerNorm on the last dimension
template <typename DType>
MSHADOW_XINLINE static void LayerNormLastDim(index_t m,
index_t n,
DType *a,
DType *b,
DType *gamma,
DType *beta,
DType *mean,
DType *var,
DType eps) {
auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
#pragma omp parallel for num_threads(nthreads)
for (index_t i = 0; i < m; i++) {
DType* in_offset = a + i * n;
DType* out_offset = b + i * n;

sum_(n, in_offset, &(mean[i]));
mean[i] /= n;
var[i] = 0.0f;
#if !defined(_MSC_VER)
#pragma omp simd
#endif
for (index_t j = 0; j < n; j++) {
out_offset[j] = in_offset[j] - mean[i];
var[i] += out_offset[j] * out_offset[j];
}
var[i] = math::sqrt(var[i] / n + eps);
#if !defined(_MSC_VER)
#pragma omp simd
#endif
for (index_t j = 0; j < n; j++) {
out_offset[j] = out_offset[j] * gamma[j] / var[i] + beta[j];
}
}
}
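A hedged usage sketch of the kernel above (shapes and values are hypothetical; assumes an OpenMP-enabled build with this header included):

// Normalize each row of a 2x4 matrix; mean/var receive per-row statistics.
float a[8]     = {1.f, 2.f, 3.f, 4.f, 2.f, 4.f, 6.f, 8.f};  // input (m=2, n=4)
float b[8];                                                  // output
float gamma[4] = {1.f, 1.f, 1.f, 1.f};                       // scale
float beta[4]  = {0.f, 0.f, 0.f, 0.f};                       // shift
float mean[2], var[2];
mxnet::op::mkl_func::LayerNormLastDim<float>(2, 4, a, b, gamma, beta,
                                             mean, var, 1e-5f);
// Each row of b now has roughly zero mean and unit variance.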

} // namespace mkl_func
} // namespace op
} // namespace mxnet
#endif // MSHADOW_USE_MKL == 1
#endif // MXNET_OPERATOR_MKL_FUNCTIONS_INL_H_
9 changes: 5 additions & 4 deletions src/operator/nn/layer_norm-inl.h
@@ -63,6 +63,10 @@ struct LayerNormParam : public dmlc::Parameter<LayerNormParam> {
}
};

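// Maps a negative axis to its non-negative equivalent,
// e.g. axis -1 on a 4-d tensor becomes 3.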
static int GetRealAxis(int axis, int ndim) {
return axis < 0 ? (axis + ndim) : axis;
}

template<typename xpu>
void LayerNormCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
@@ -79,10 +83,7 @@ void LayerNormComputeGeneral(const nnvm::NodeAttrs& attrs,
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
int axis = param.axis;
if (axis < 0) {
axis += static_cast<int>(inputs[0].ndim());
}
int axis = GetRealAxis(param.axis, inputs[0].ndim());
CHECK(axis >= 0 && axis < inputs[0].ndim()) << "Channel axis out of range: " << param.axis;
CHECK_EQ(inputs.size(), 3U);
Stream<xpu> *s = ctx.get_stream<xpu>();
58 changes: 53 additions & 5 deletions src/operator/nn/layer_norm.cc
@@ -27,6 +27,10 @@
#include <nnvm/op_attr_types.h>
#include "../elemwise_op_common.h"

#if MSHADOW_USE_MKL == 1
#include "../mkl_functions-inl.h"
#endif

namespace mxnet {
namespace op {

@@ -39,10 +43,7 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
int axis = param.axis;
if (axis < 0) {
axis += dshape.ndim();
}
int axis = GetRealAxis(param.axis, dshape.ndim());
CHECK(axis >= 0 && axis < dshape.ndim())
<< "Channel axis out of range: axis=" << param.axis;

@@ -64,7 +65,6 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
return true;
}


template<>
void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
@@ -73,6 +73,50 @@ void LayerNormCompute<cpu>(const nnvm::NodeAttrs& attrs,
return LayerNormComputeGeneral<cpu>(attrs, ctx, inputs, req, outputs);
}

#if MSHADOW_USE_MKL == 1
void LayerNormComputeMKL(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mshadow;
const LayerNormParam& param = nnvm::get<LayerNormParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK_NE(req[0], kAddTo);
CHECK_EQ(inputs.size(), 3U);
int axis = GetRealAxis(param.axis, inputs[0].ndim());

if (axis == (inputs[layernorm::kData].ndim() - 1) &&
(inputs[0].type_flag_ == kFloat32 || inputs[0].type_flag_ == kFloat64)) {
// Compute necessary data for the reduce operation.
mxnet::TShape red_src_shape, red_dst_shape;
BroadcastReduceShapeCompact(inputs[layernorm::kData].shape_, outputs[layernorm::kMean].shape_,
&red_src_shape, &red_dst_shape);
const TBlob in_data = inputs[layernorm::kData].reshape(red_src_shape);
const TBlob mean_data = outputs[layernorm::kMean].reshape(red_dst_shape);
const TBlob std_data = outputs[layernorm::kStd].reshape(red_dst_shape);
const index_t outer_size = red_dst_shape.Size();
const index_t channel_size = red_src_shape.Size() / red_dst_shape.Size();

// Normalize over the last dimension with the MKL-backed kernel.
MSHADOW_SGL_DBL_TYPE_SWITCH(in_data.type_flag_, DType, {
mkl_func::LayerNormLastDim(outer_size, channel_size,
in_data.dptr<DType>(),
outputs[layernorm::kOut].dptr<DType>(),
inputs[layernorm::kGamma].dptr<DType>(),
inputs[layernorm::kBeta].dptr<DType>(),
outputs[layernorm::kMean].dptr<DType>(),
outputs[layernorm::kStd].dptr<DType>(),
static_cast<DType>(param.eps));
});
} else {
// fallback
LayerNormCompute<cpu>(attrs, ctx, inputs, req, outputs);
}
}
#endif
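Note the two-level dispatch: the MKL kernel is compiled in only when MSHADOW_USE_MKL is set, and even then it handles only float32/float64 tensors normalized over the last axis, falling back to the generic LayerNormCompute<cpu> for every other case.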


template<>
void LayerNormGradCompute<cpu>(const nnvm::NodeAttrs& attrs,
const OpContext& ctx, const std::vector<TBlob>& inputs,
Expand Down Expand Up @@ -126,7 +170,11 @@ axis to be the last item in the input shape.
})
.set_attr<mxnet::FInferShape>("FInferShape", LayerNormShape)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
#if MSHADOW_USE_MKL == 1
.set_attr<FCompute>("FCompute<cpu>", LayerNormComputeMKL)
#else
.set_attr<FCompute>("FCompute<cpu>", LayerNormCompute<cpu>)
#endif
.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) {
std::vector<nnvm::NodeEntry> heads;
Expand Down
