Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN softmax #16246

Merged
merged 2 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,6 @@ void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

/* For softmax */
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
Expand Down Expand Up @@ -133,6 +128,11 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs, const OpReqType &req,
const NDArray &output);

/* For softmax */
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);

Expand Down
42 changes: 24 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,30 @@
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
namespace mxnet {
namespace op {

static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(
bool is_train, const int axis,
const mkldnn::memory &input_mem) {
mkldnn::memory::desc data_md = input_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto prop = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}


bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
const NDArray &data,
const NDArray &output) {
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
const int ndim = data.shape().ndim();
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();

const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
Expand All @@ -48,21 +61,12 @@ bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
axis != (ndim - 1)) {
return false;
}

// only supports ndim = 1, 2, 3, 4 for now
return (ndim >= 1 && ndim <= 4);
}

static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
const bool is_train,
const mkldnn::memory &input) {
auto data_md = input.get_primitive_desc().desc();
auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
return pd;
}

void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &in_data,
const OpReqType &req,
Expand All @@ -71,21 +75,23 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
// same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
const int axis = CheckAxis(param.axis, in_data.shape().ndim());

int axis = CheckAxis(param.axis, in_data.shape().ndim());
NDArray data = in_data;
if (in_data.IsView() && in_data.IsMKLDNNData()) {
data = in_data.Reorder2Default();
}

auto data_mem = data.GetMKLDNNData();
auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
auto pd = GetSoftmaxFwdPd(ctx.is_train, axis, *data_mem);
auto out_mem = CreateMKLDNNMem(out_data, pd.dst_desc(), req);
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
stream->RegisterPrimArgs(pd,
{{MKLDNN_ARG_SRC, *data_mem}, {MKLDNN_ARG_DST, *out_mem.second}});
CommitOutput(out_data, out_mem);
stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif

6 changes: 3 additions & 3 deletions src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
#include "../tensor/elemwise_unary_op.h"
#include "../tensor/elemwise_binary_op.h"
#include "../operator_common.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "mkldnn/mkldnn_base-inl.h"
#include "mkldnn/mkldnn_ops-inl.h"
#endif
Expand All @@ -35,7 +35,7 @@ namespace mxnet {
namespace op {
DMLC_REGISTER_PARAMETER(SoftmaxParam);

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
Expand Down Expand Up @@ -114,7 +114,7 @@ Example::
return std::vector<std::string>{"output"};
})
.set_attr<FCompute>("FCompute<cpu>", SoftmaxCompute<cpu, mxnet_op::softmax_fwd>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
Expand Down