[mkldnn-v1.0] Add MKL-DNN softmax (#16246)
* add mkldnn softmax

* trigger CI
rongzha1 authored and pengzhao-intel committed Sep 30, 2019
1 parent 4fba4c3 commit a559760
Showing 3 changed files with 32 additions and 26 deletions.
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -64,11 +64,6 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &
                                 const std::vector<OpReqType>& req,
                                 const std::vector<NDArray>& outputs);

/* For softmax */
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const NDArray &in_data, const OpReqType &req,
                          const NDArray &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray> &in_data,
@@ -133,6 +128,11 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                             const std::vector<NDArray>& inputs, const OpReqType &req,
                             const NDArray &output);

/* For softmax */
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const NDArray &in_data, const OpReqType &req,
                          const NDArray &out_data);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
               const mkldnn::memory &out);

42 changes: 24 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -27,17 +27,30 @@
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
namespace mxnet {
namespace op {

static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(
    bool is_train, const int axis,
    const mkldnn::memory &input_mem) {
  mkldnn::memory::desc data_md = input_mem.get_desc();
  auto cpu_engine = CpuEngine::Get()->get_engine();
  auto prop = is_train ? mkldnn::prop_kind::forward_training
                       : mkldnn::prop_kind::forward_scoring;
  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}


bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
                          const NDArray &data,
                          const NDArray &output) {
  // MKLDNN does not support temperature argument in their softmax function
  // now. Need update this once they start to support it.
  const int ndim = data.shape().ndim();
  const int in_dtype = data.dtype();
  const int out_dtype = output.dtype();

  const int axis = CheckAxis(param.axis, ndim);
  // MKLDNN does not support temperature argument in their softmax function
  // now. Need update this once they start to support it.
@@ -48,21 +61,12 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
      axis != (ndim - 1)) {
    return false;
  }

  // only supports ndim = 1, 2, 3, 4 for now
  return (ndim >= 1 && ndim <= 4);
}

static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
                                                               const bool is_train,
                                                               const mkldnn::memory &input) {
  auto data_md = input.get_primitive_desc().desc();
  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
  auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
  return pd;
}

void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
                          const OpContext &ctx,
                          const NDArray &in_data,
                          const OpReqType &req,
@@ -71,21 +75,23 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
  CHECK_NE(req, kAddTo);
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  const int axis = CheckAxis(param.axis, in_data.shape().ndim());

  int axis = CheckAxis(param.axis, in_data.shape().ndim());
  NDArray data = in_data;
  if (in_data.IsView() && in_data.IsMKLDNNData()) {
    data = in_data.Reorder2Default();
  }

  auto data_mem = data.GetMKLDNNData();
  auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
  auto pd = GetSoftmaxFwdPd(ctx.is_train, axis, *data_mem);
  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_desc(), req);
  MKLDNNStream *stream = MKLDNNStream::Get();
  stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
  stream->RegisterPrimArgs(pd,
                           {{MKLDNN_ARG_SRC, *data_mem}, {MKLDNN_ARG_DST, *out_mem.second}});
  CommitOutput(out_data, out_mem);
  stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
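
The substance of this port is the switch from the MKL-DNN v0.x API to v1.0: memory descriptors now come from get_desc() instead of get_primitive_desc().desc(), the output descriptor from pd.dst_desc() instead of pd.dst_primitive_desc(), and primitives are queued through RegisterPrimArgs with an explicit argument map rather than a fully bound RegisterPrim call. For reference, below is a minimal standalone sketch of that v1.0 calling convention; it is not part of the commit, and the tensor shape, axis, and sample values are illustrative assumptions.

// Standalone sketch (not from this commit): the MKL-DNN v1.0 calling
// convention that MKLDNNSoftmaxForward now follows. Builds against the
// MKL-DNN 1.0 C++ header; shape, axis, and data are made-up examples.
#include <mkldnn.hpp>

#include <algorithm>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream strm(eng);

  // v1.0 works with a plain memory::desc; v0.x wrapped it in a primitive_desc.
  memory::desc data_md({2, 4}, memory::data_type::f32, memory::format_tag::nc);
  memory src_mem(data_md, eng);
  memory dst_mem(data_md, eng);

  std::vector<float> src = {1.f, 2.f, 3.f, 4.f, 1.f, 1.f, 1.f, 1.f};
  std::copy(src.begin(), src.end(),
            static_cast<float *>(src_mem.get_data_handle()));

  // Same construction sequence as GetSoftmaxFwdPd: op desc -> primitive desc.
  auto desc = softmax_forward::desc(prop_kind::forward_scoring, data_md,
                                    /*axis=*/1);
  auto pd = softmax_forward::primitive_desc(desc, eng);

  // v1.0 execution: a stateless primitive plus an argument map, which is
  // exactly what RegisterPrimArgs records for deferred submission.
  softmax_forward(pd).execute(strm, {{MKLDNN_ARG_SRC, src_mem},
                                     {MKLDNN_ARG_DST, dst_mem}});
  strm.wait();
  return 0;
}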

6 changes: 3 additions & 3 deletions src/operator/nn/softmax.cc
@@ -26,7 +26,7 @@
#include "../tensor/elemwise_unary_op.h"
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "mkldnn/mkldnn_base-inl.h"
#include "mkldnn/mkldnn_ops-inl.h"
#endif
@@ -35,7 +35,7 @@ namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(SoftmaxParam);

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<NDArray>& inputs,
@@ -114,7 +114,7 @@ Example::
    return std::vector<std::string>{"output"};
  })
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
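
The registration above routes CPU softmax through SoftmaxComputeExCPU, whose body is collapsed in this view. The usual MXNet dispatch pattern for such a wrapper, sketched below as an assumption rather than the commit's verbatim code, is to consult SupportMKLDNNSoftmax and fall back to the generic FCompute kernel otherwise.

// Sketch of the typical MXNet MKL-DNN dispatch wrapper; the commit's actual
// SoftmaxComputeExCPU body is collapsed above and may differ in detail.
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
    // Eligible dtype/axis/ndim combinations take the MKL-DNN primitive
    // added in this commit.
    MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
    return;
  }
  // Everything else falls back to the generic CPU implementation.
  FallBackCompute(SoftmaxCompute<cpu, mxnet_op::softmax_fwd>,
                  attrs, ctx, inputs, req, outputs);
}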
