diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 3da3f23d7683..b05157942131 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -180,7 +180,7 @@ bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
-bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
+bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 }  // namespace op
diff --git a/src/operator/nn/mkldnn/mkldnn_softmax.cc b/src/operator/nn/mkldnn/mkldnn_softmax.cc
index 7268ed39339e..77ab43b63fd5 100644
--- a/src/operator/nn/mkldnn/mkldnn_softmax.cc
+++ b/src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -26,43 +26,66 @@
 #include "../softmax-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"
-#include "../../tensor/broadcast_reduce_op.h"

 #if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {

-bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
+bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
+                          const NDArray &data,
+                          const NDArray &output) {
+  const int ndim = data.shape().ndim();
+  const int in_dtype = data.dtype();
+  const int out_dtype = output.dtype();
+
+  const int axis = CheckAxis(param.axis, ndim);
   // MKLDNN does not support temperature argument in their softmax function
   // now. Need update this once they start to support it.
-  if (param.temperature.has_value()) {
+  // Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
+  if (param.temperature.has_value() ||
+      in_dtype != mshadow::kFloat32 ||
+      in_dtype != out_dtype ||
+      axis != (ndim - 1)) {
     return false;
   }
-  return true;
+  // only supports ndim = 1, 2, 3, 4 for now
+  return (ndim >= 1 && ndim <= 4);
+}
+
+static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
+                                                               const bool is_train,
+                                                               const mkldnn::memory &input) {
+  auto data_md = input.get_primitive_desc().desc();
+  auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
+  auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
+  return pd;
 }

-void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                          const NDArray &in_data, const OpReqType &req,
+void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
+                          const OpContext &ctx,
+                          const NDArray &in_data,
+                          const OpReqType &req,
                           const NDArray &out_data) {
+  if (req == kNullOp) return;
+  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
+  CHECK_NE(req, kAddTo);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  auto input_mem = in_data.GetMKLDNNData();
-  mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  int axis = CheckAxis(param.axis, in_data.shape().ndim());
+  const int axis = CheckAxis(param.axis, in_data.shape().ndim());

-  auto cpu_engine = data_mpd.get_engine();
-  auto prop = ctx.is_train
-    ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
-  mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
-                                                                     data_md, axis);
-  mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);
+  NDArray data = in_data;
+  if (in_data.IsView() && in_data.IsMKLDNNData()) {
+    data = in_data.Reorder2Default();
+  }

-  auto output_memory = out_data.GetMKLDNNData();
+  auto data_mem = data.GetMKLDNNData();
+  auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
+  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::softmax_forward(pdesc, *input_mem, *output_memory));
+  stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
+  CommitOutput(out_data, out_mem);
   stream->Submit();
 }
-
 }  // namespace op
 }  // namespace mxnet
 #endif
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index b84dd93300f8..e44bbbb6b8f6 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -43,7 +43,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& outputs) {
   // It seems MKLDNN softmax doesn't support training.
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
+  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
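
For reference, the new dispatch gate in SupportMKLDNNSoftmax reduces to four independent conditions: no temperature argument, fp32 input and output of the same dtype, softmax over the last axis, and 1 to 4 dimensions. The standalone sketch below restates that logic so it can be compiled and tested in isolation; TensorInfo, kFloat32Tag, and CanUseMKLDNNSoftmax are hypothetical names used only for illustration (the real check takes SoftmaxParam and NDArray arguments and compares against mshadow::kFloat32), and it assumes axis has already been normalized to a non-negative value, as CheckAxis does in the patch.

    // Minimal sketch of the gating logic introduced by this patch.
    // All names here are illustrative stand-ins, not MXNet APIs.
    #include <cassert>

    struct TensorInfo {
      int ndim;   // number of dimensions
      int dtype;  // element-type tag
    };

    constexpr int kFloat32Tag = 0;  // stand-in for mshadow::kFloat32

    bool CanUseMKLDNNSoftmax(bool has_temperature, int axis,
                             const TensorInfo &in, const TensorInfo &out) {
      // Reject when a temperature is set (MKLDNN softmax has no such
      // argument), when input/output are not both fp32, or when softmax
      // is not taken over the last axis (slow in MKLDNN).
      if (has_temperature || in.dtype != kFloat32Tag ||
          in.dtype != out.dtype || axis != in.ndim - 1) {
        return false;
      }
      // Only 1- to 4-dimensional tensors are supported for now.
      return in.ndim >= 1 && in.ndim <= 4;
    }

    int main() {
      TensorInfo t{2, kFloat32Tag};
      assert(CanUseMKLDNNSoftmax(false, 1, t, t));   // fp32, last axis: MKLDNN path
      assert(!CanUseMKLDNNSoftmax(false, 0, t, t));  // axis 0 of a 2-D tensor: fall back
      assert(!CanUseMKLDNNSoftmax(true, 1, t, t));   // temperature set: fall back
      return 0;
    }

Note that these conditions also subsume the old call-site test in SoftmaxComputeExCPU: the SupportMKLDNN(inputs[0]) and !ctx.is_train checks are dropped there because the new function validates dtype and shape itself, and the forward primitive now selects forward_training or forward_scoring from ctx.is_train.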