add mkldnn softmax backward (apache#17170)
* add mkldnn softmax backward

* add primitive cache for softmax bwd

* fix failing pre-CI test

* remove duplicate line
rongzha1 authored and anirudh2290 committed May 29, 2020
1 parent 25d8ecf commit 17b46c8
Showing 3 changed files with 127 additions and 2 deletions.
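For orientation: the patch drives MKLDNN's (DNNL 1.x) softmax_backward primitive. It builds a forward primitive descriptor as the required "hint", creates the backward descriptor from the gradient and data memory descriptors, and executes the primitive with the forward output (DST), the incoming gradient (DIFF_DST), and the input gradient to be written (DIFF_SRC). Below is a minimal standalone C++ sketch of that flow using the plain mkldnn 1.x API outside MXNet; the shapes, axis, values, and the main() wrapper are illustrative assumptions, not code from this commit.

// Standalone sketch of the softmax backward flow this commit wires into MXNet.
// Assumes the mkldnn 1.x C++ API (mkldnn.hpp) with the MKLDNN_ARG_* macros.
#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  const int axis = 1;                       // softmax over the class dimension
  memory::dims shape = {2, 3};              // (batch, classes), illustrative
  memory::desc md(shape, memory::data_type::f32, memory::format_tag::nc);

  std::vector<float> y  = {0.1f, 0.2f, 0.7f, 0.3f, 0.3f, 0.4f};  // forward softmax output
  std::vector<float> dy = {0.f, 0.f, 1.f, 1.f, 0.f, 0.f};        // incoming gradient
  std::vector<float> dx(y.size(), 0.f);                          // gradient w.r.t. the input

  memory y_mem(md, eng, y.data());
  memory dy_mem(md, eng, dy.data());
  memory dx_mem(md, eng, dx.data());

  // Forward pd is only needed as the hint required by the backward pd,
  // mirroring GetSoftmaxFwdPd(true, ...) in the patch.
  auto fwd_desc = softmax_forward::desc(prop_kind::forward_training, md, axis);
  auto fwd_pd = softmax_forward::primitive_desc(fwd_desc, eng);

  auto bwd_desc = softmax_backward::desc(md /*diff_md*/, md /*data_md*/, axis);
  auto bwd_pd = softmax_backward::primitive_desc(bwd_desc, eng, fwd_pd);

  softmax_backward bwd(bwd_pd);
  bwd.execute(s, {{MKLDNN_ARG_DST, y_mem},
                  {MKLDNN_ARG_DIFF_DST, dy_mem},
                  {MKLDNN_ARG_DIFF_SRC, dx_mem}});
  s.wait();
  return 0;
}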
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -93,6 +93,10 @@ void MKLDNNLeakyReluBackward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                          const NDArray &in_data, const OpReqType &req,
                          const NDArray &out_data);
void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data);

/* For softmax_output */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
84 changes: 84 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_softmax.cc
@@ -42,6 +42,18 @@ static mkldnn::softmax_forward::primitive_desc GetSoftmaxFwdPd(bool is_train,
  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}

static mkldnn::softmax_backward::primitive_desc GetSoftmaxBwdPd(
    const mkldnn::memory &diff_mem,
    const mkldnn::memory &data_mem,
    const int axis,
    const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) {
  mkldnn::memory::desc diff_md = diff_mem.get_desc();
  mkldnn::memory::desc data_md = data_mem.get_desc();
  auto cpu_engine = CpuEngine::Get()->get_engine();
  auto desc = mkldnn::softmax_backward::desc(diff_md, data_md, axis);
  return mkldnn::softmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
}
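The hint_fwd_pd argument is required by MKLDNN: a backward primitive descriptor can only be created relative to a forward one. The body of GetSoftmaxFwdPd is elided in this diff; the following is a sketch of what such a helper looks like under the same API, where the name MakeSoftmaxFwdHint and the exact prop_kind choice are assumptions rather than the commit's code.

// Assumed helper sketch: build the forward pd used as a hint by GetSoftmaxBwdPd.
static mkldnn::softmax_forward::primitive_desc MakeSoftmaxFwdHint(
    const mkldnn::memory::desc &data_md, const int axis, const bool is_train) {
  // Training vs. inference only changes the propagation kind of the descriptor.
  auto prop = is_train ? mkldnn::prop_kind::forward_training
                       : mkldnn::prop_kind::forward_scoring;
  auto desc = mkldnn::softmax_forward::desc(prop, data_md, axis);
  auto cpu_engine = CpuEngine::Get()->get_engine();
  return mkldnn::softmax_forward::primitive_desc(desc, cpu_engine);
}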


bool SupportMKLDNNSoftmax(const SoftmaxParam &param,
                          const NDArray &data,
@@ -131,6 +143,78 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs,
  stream->Submit();
}

class MKLDNNSoftmaxBwd {
 public:
  mkldnn::softmax_backward::primitive_desc pd;

  MKLDNNSoftmaxBwd(const mkldnn::memory &diff_mem,
                   const mkldnn::memory &data_mem,
                   const int axis,
                   const mkldnn::softmax_forward::primitive_desc &hint_fwd_pd) :
                   pd(GetSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
    bwd_ = std::make_shared<mkldnn::softmax_backward>(pd);
  }

  const mkldnn::softmax_backward &GetBwd() const {
    return *bwd_;
  }

 private:
  std::shared_ptr<mkldnn::softmax_backward> bwd_;
};

static MKLDNNSoftmaxBwd &GetSoftmaxBwd(const SoftmaxParam &param,
                                       const int real_axis,
                                       const std::vector<NDArray> &data,
                                       const std::vector<NDArray> &output) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNSoftmaxSignature, MKLDNNSoftmaxBwd, OpHash> bwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNSoftmaxSignature, MKLDNNSoftmaxBwd, OpHash> bwds;
#endif

  MKLDNNSoftmaxSignature key(param);
  key.AddSign(real_axis);
  key.AddSign(data);
  key.AddSign(output);

  auto it = bwds.find(key);
  if (it == bwds.end()) {
    auto diff_mem = data[0].GetMKLDNNData();
    auto data_mem = data[1].GetMKLDNNData();
    auto fwd_pd = GetSoftmaxFwdPd(true, real_axis, *data_mem);
    MKLDNNSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
    it = AddToCache(&bwds, key, bwd);
  }
  return it->second;
}
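GetSoftmaxBwd() above follows MXNet's usual MKLDNN primitive-cache pattern: a per-thread hash map keyed by a signature of the parameters and tensors, so the relatively expensive primitive construction happens once per distinct configuration and is reused afterwards. A stripped-down, self-contained sketch of the same pattern follows; CachedKernel and the string key are placeholders, not MXNet types.

// Minimal sketch of the thread-local primitive-cache pattern (assumed, not from the patch).
#include <string>
#include <unordered_map>

struct CachedKernel {
  int axis;                    // stand-in for an expensive-to-build primitive
  explicit CachedKernel(int a) : axis(a) {}
};

static CachedKernel &GetCachedKernel(const std::string &shape_sig, int axis) {
  // One cache per thread avoids locking, at the cost of per-thread copies --
  // the same trade-off DMLC_CXX11_THREAD_LOCAL / MX_THREAD_LOCAL makes above.
  static thread_local std::unordered_map<std::string, CachedKernel> cache;
  const std::string key = shape_sig + ":" + std::to_string(axis);
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, CachedKernel(axis)).first;   // construct on miss
  return it->second;                                     // reuse on later calls
}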

void MKLDNNSoftmaxBackward(const nnvm::NodeAttrs& attrs,
                           const OpContext &ctx,
                           const std::vector<NDArray> &in_data,
                           const std::vector<OpReqType> &req,
                           const std::vector<NDArray> &out_data) {
  if (req[0] == kNullOp) return;
  CHECK_EQ(in_data.size(), 2U);
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  int axis = CheckAxis(param.axis, in_data[1].shape().ndim());
  auto diff_mem = in_data[0].GetMKLDNNData();
  auto data_mem = in_data[1].GetMKLDNNData();
  auto bwd = GetSoftmaxBwd(param, axis, in_data, out_data);

  auto out_mem = CreateMKLDNNMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
  MKLDNNStream *stream = MKLDNNStream::Get();
  mkldnn_args_map_t args = {
    { MKLDNN_ARG_DST, *data_mem },
    { MKLDNN_ARG_DIFF_DST, *diff_mem },
    { MKLDNN_ARG_DIFF_SRC, *out_mem.second }
  };

  stream->RegisterPrimArgs(bwd.GetBwd(), args);
  CommitOutput(out_data[0], out_mem);
  stream->Submit();
}

} // namespace op
} // namespace mxnet
#endif
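For reference, the gradient that MKLDNNSoftmaxBackward asks the primitive to write into DIFF_SRC is the standard softmax gradient dx = (dy - sum(dy * y)) * y taken along the softmax axis, where y is the forward output (DST) and dy the incoming gradient (DIFF_DST). A plain C++ sketch of that math for the 2-D, axis = 1 case follows; the helper name and layout are illustrative, not part of the commit.

// Reference math the MKLDNN path above is expected to reproduce (a sketch, not patch code).
#include <cstddef>
#include <vector>

void softmax_backward_ref(const std::vector<float> &y,
                          const std::vector<float> &dy,
                          std::vector<float> *dx,
                          std::size_t rows, std::size_t cols) {
  // 2-D case with the softmax taken over the last dimension (axis = 1).
  for (std::size_t r = 0; r < rows; ++r) {
    float dot = 0.f;
    for (std::size_t c = 0; c < cols; ++c)
      dot += dy[r * cols + c] * y[r * cols + c];          // sum(dy * y) per row
    for (std::size_t c = 0; c < cols; ++c)
      (*dx)[r * cols + c] = (dy[r * cols + c] - dot) * y[r * cols + c];
  }
}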
41 changes: 39 additions & 2 deletions src/operator/nn/softmax.cc
@@ -41,7 +41,6 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const std::vector<NDArray>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs) {
  // It seems MKLDNN softmax doesn't support training.
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -54,6 +53,23 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
                  inputs, req, outputs);
}

static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
                                    const std::vector<OpReqType>& req,
                                    const std::vector<NDArray>& outputs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNSoftmaxBackward, attrs, ctx, inputs, req, outputs);
    auto fn = SoftmaxGradCompute<cpu, op::mshadow_op::mul, mxnet_op::softmax_bwd>;
    MKLDNN_OPCHECK_RUN(fn, attrs, ctx, inputs, req, outputs);
    return;
  }
  FallBackCompute(SoftmaxGradCompute<cpu, op::mshadow_op::mul, mxnet_op::softmax_bwd>, attrs, ctx,
                  inputs, req, outputs);
}

inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
                                      const int dev_mask,
                                      DispatchMode* dispatch_mode,
@@ -72,6 +88,23 @@ inline static bool SoftmaxStorageType(const nnvm::NodeAttrs& attrs,
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                           out_attrs);
}

inline static bool SoftmaxGradStorageType(const nnvm::NodeAttrs& attrs,
                                          const int dev_mask,
                                          DispatchMode* dispatch_mode,
                                          std::vector<int> *in_attrs,
                                          std::vector<int> *out_attrs) {
  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
  if (param.use_length.value() || softmax_has_dtype_override(attrs)) {
    auto& out_stype = out_attrs->at(0);
    return storage_type_assign(&out_stype, kDefaultStorage,
                               dispatch_mode, DispatchMode::kFCompute);
  }
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  return MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs,
                           out_attrs);
}
#endif


@@ -147,8 +180,12 @@ NNVM_REGISTER_OP(_backward_softmax)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", SoftmaxGradOpInplaceOption)
.add_argument("args", "NDArray-or-Symbol[]", "Positional input arguments")
.set_attr_parser(ParamParser<SoftmaxParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxGradComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxGradStorageType)
#endif
.set_attr<FCompute>("FCompute<cpu>", SoftmaxGradCompute<cpu, op::mshadow_op::mul,
                                                       mxnet_op::softmax_bwd>);

} // namespace op
} // namespace mxnet
