Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support 3D input for MKL-DNN softmax operator #14818

Merged
merged 13 commits into from
May 17, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
bool SupportMKLDNNTranspose(const TransposeParam& param, const NDArray &data);
} // namespace op
Expand Down
61 changes: 42 additions & 19 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,43 +26,66 @@
#include "../softmax-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"
#include "../../tensor/broadcast_reduce_op.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

bool SupportMKLDNNSoftmax(const SoftmaxParam &param) {
bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
const NDArray &data,
const NDArray &output) {
const int ndim = data.shape().ndim();
const int in_dtype = data.dtype();
const int out_dtype = output.dtype();

const int axis = CheckAxis(param.axis, ndim);
// MKLDNN does not support temperature argument in their softmax function
// now. Need update this once they start to support it.
if (param.temperature.has_value()) {
// Currently, MKLDNN shows bad performance when softmax is not performed on the last dimension
if (param.temperature.has_value() ||
in_dtype != mshadow::kFloat32 ||
in_dtype != out_dtype ||
axis != (ndim - 1)) {
return false;
}
return true;
// only supports ndim = 1, 2, 3, 4 for now
return (ndim >= 1 && ndim <= 4);
}

// Build the MKL-DNN forward-softmax primitive descriptor for `input`,
// applying softmax along `axis`. `is_train` selects the training vs.
// scoring propagation kind.
static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(const int axis,
                                                               const bool is_train,
                                                               const mkldnn::memory &input) {
  const auto kind = is_train ? mkldnn::prop_kind::forward_training
                             : mkldnn::prop_kind::forward_scoring;
  const auto md = input.get_primitive_desc().desc();
  mkldnn::softmax_forward::desc fwd_desc(kind, md, axis);
  return mkldnn::softmax_forward::primitive_desc(fwd_desc,
                                                 CpuEngine::Get()->get_engine());
}

// MKL-DNN implementation of the softmax forward pass.
//
// attrs    - carries the parsed SoftmaxParam (axis, temperature).
// ctx      - op context; ctx.is_train picks the primitive's prop kind.
// in_data  - input NDArray.
// req      - write request; only kWriteTo/kWriteInplace (and kNullOp) are
//            supported, matching the FCompute path.
// out_data - output NDArray receiving the softmax result.
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs &attrs,
                          const OpContext &ctx,
                          const NDArray &in_data,
                          const OpReqType &req,
                          const NDArray &out_data) {
  if (req == kNullOp) return;
  // same as the FCompute path, softmax only supports kWriteTo and
  // kWriteInplace for now.
  CHECK_NE(req, kAddTo);
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  const int axis = CheckAxis(param.axis, in_data.shape().ndim());

  // A view over an MKL-DNN-laid-out array cannot be consumed directly;
  // reorder it back to the default layout first.
  NDArray data = in_data;
  if (in_data.IsView() && in_data.IsMKLDNNData()) {
    data = in_data.Reorder2Default();
  }

  auto data_mem = data.GetMKLDNNData();
  auto pd = GetSoftmaxFwdPd(axis, ctx.is_train, *data_mem);
  // CreateMKLDNNMem honors `req` (in-place vs. fresh buffer) and returns a
  // (commit-kind, memory) pair; CommitOutput finalizes the write below.
  auto out_mem = CreateMKLDNNMem(out_data, pd.dst_primitive_desc(), req);
  MKLDNNStream *stream = MKLDNNStream::Get();
  stream->RegisterPrim(mkldnn::softmax_forward(pd, *data_mem, *out_mem.second));
  CommitOutput(out_data, out_mem);
  stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
2 changes: 1 addition & 1 deletion src/operator/nn/softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& outputs) {
// It seems MKLDNN softmax doesn't support training.
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNN(inputs[0]) && !ctx.is_train && SupportMKLDNNSoftmax(param)) {
if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
MKLDNNSoftmaxForward(attrs, ctx, inputs[0], req[0], outputs[0]);
auto fn = SoftmaxCompute<cpu, mxnet_op::softmax_fwd>;
Expand Down