This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support 3D input for MKL-DNN softmax operator (#14818)
* add 3d softmax

* fix

* handle req type

* clean code

* remove check

* check axis

* retrigger ci
TaoLv authored and pengzhao-intel committed May 17, 2019
1 parent d87bd2a · commit 8d6ac4a
Showing 3 changed files with 44 additions and 21 deletions.
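For context before the diff: with an MXNet build that has MKL-DNN enabled, this change lets a softmax over the last axis of a 3D float32 input be handled by the MKL-DNN primitive instead of the generic CPU kernel. A minimal sketch of the user-visible call (assuming the `mxnet` Python package from such a build; the shape and variable names below are illustrative only):

```python
import mxnet as mx

# A 3D float32 input, e.g. (batch, sequence, vocabulary).
x = mx.nd.random.uniform(shape=(2, 5, 10), dtype='float32')

# Softmax over the last axis: with this commit, an MKL-DNN build can
# dispatch this call to the MKL-DNN softmax primitive.
y = mx.nd.softmax(x, axis=-1)

# Each slice along the last axis sums to 1.
print(y.sum(axis=-1))
```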
src/operator/nn/mkldnn/mkldnn_base-inl.h (2 changes: 1 addition & 1 deletion)
@@ -180,7 +180,7 @@ bool SupportMKLDNNAct(const ActivationParam& param);
 bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
 bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
 bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
-bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
+bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
 bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
 bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
 }  // namespace op
src/operator/nn/mkldnn/mkldnn_softmax.cc (61 changes: 42 additions & 19 deletions)
@@ -26,43 +26,66 @@
#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../tensor/broadcast_reduce_op.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
const NDArray &data,
const NDArray &output) {
const int ndim = data.shape().ndim();
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();

const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
if (param.temperature.has_value()) {
// Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
}
return true;
// only supports ndim = 1, 2, 3, 4 for now
return (ndim >= 1 && ndim <= 4);
}

static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
const bool is_train,
const mkldnn::memory &input) {
auto data_md = input.get_primitive_desc().desc();
auto prop = is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
auto pd = mkldnn::softmax_forward::primitive_desc(desc, CpuEngine::Get()->get_engine());
return pd;
}

void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const NDArray &in_data,
const OpReqType &req,
const NDArray &out_data) {
if (req == kNullOp) return;
// same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
CHECK_NE(req, kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
auto input_mem = in_data.GetMKLDNNData();
mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
int axis = CheckAxis(param.axis, in_data.shape().ndim());
const int axis = CheckAxis(param.axis, in_data.shape().ndim());

auto cpu_engine = data_mpd.get_engine();
auto prop = ctx.is_train
? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
mkldnn::softmax_forward::desc desc = mkldnn::softmax_forward::desc(prop,
data_md, axis);
mkldnn::softmax_forward::primitive_desc pdesc(desc, cpu_engine);
NDArray data = in_data;
if (in_data.IsView() && in_data.IsMKLDNNData()) {
data = in_data.Reorder2Default();
}

auto output_memory = out_data.GetMKLDNNData();
auto data_mem = data.GetMKLDNNData();
auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(mkldnn::softmax_forward(pdesc, *input_mem, *output_memory));
stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
CommitOutput(out_data, out_mem);
stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
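The conditions in SupportMKLDNNSoftmax above amount to a simple rule for which calls are eligible for the MKL-DNN path: float32 input and output, softmax on the last axis, ndim between 1 and 4, and no temperature. A sketch of calls that do and do not satisfy those checks (illustrative only; which kernel actually runs still depends on the build and runtime dispatch):

```python
import mxnet as mx

x3d = mx.nd.ones((2, 5, 10), dtype='float32')

# Meets every condition: float32, last axis, ndim <= 4, no temperature.
mx.nd.softmax(x3d, axis=-1)

# Each of these fails one condition and falls back to the default CPU kernel:
mx.nd.softmax(x3d, axis=1)                      # not the last axis
mx.nd.softmax(x3d, temperature=2.0)             # temperature is set
mx.nd.softmax(x3d.astype('float64'), axis=-1)   # not float32
```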
src/operator/nn/softmax.cc (2 changes: 1 addition & 1 deletion)
@@ -43,7 +43,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                 const std::vector<NDArray>& outputs) {
   // It seems MKLDNN softmax doesn't support training.
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
+  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
     MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
     MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
     auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
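The dispatch above keeps the MKLDNN_OPCHECK hooks, which in debug builds compare the MKL-DNN result against the regular SoftmaxCompute kernel. A similar sanity check can be done from Python against a NumPy reference (a sketch, assuming `numpy` is installed alongside the MXNet build):

```python
import mxnet as mx
import numpy as np

x = mx.nd.random.uniform(shape=(2, 5, 10), dtype='float32')
y = mx.nd.softmax(x, axis=-1).asnumpy()

# NumPy reference softmax over the last axis.
a = x.asnumpy()
e = np.exp(a - a.max(axis=-1, keepdims=True))
ref = e / e.sum(axis=-1, keepdims=True)

np.testing.assert_allclose(y, ref, rtol=1e-5, atol=1e-6)
```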
