Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN softmax_output #16222

Merged
merged 1 commit into the base branch from the contributor's branch on
Oct 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
/* For softmax_output: MKL-DNN accelerated forward pass of the SoftmaxOutput
 * operator (declaration only; implemented in mkldnn_softmax_output.cc).
 * \param attrs    node attributes carrying the parsed SoftmaxOutputParam
 * \param ctx      operator execution context
 * \param in_data  input NDArrays (data, label)
 * \param req      write-request type per output (e.g. kWriteTo / kWriteInplace)
 * \param out_data output NDArrays receiving the softmax result */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray> &in_data,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &out_data);

/* For sum */
void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
Expand Down Expand Up @@ -121,6 +115,7 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
/* Backward pass of Activation via MKL-DNN: computes the gradient w.r.t. the
 * activation input from out_grad and in_data, writing into in_grad according
 * to req. Declaration only; see the activation implementation file. */
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &out_grad, const NDArray &in_data,
                              const OpReqType &req, const NDArray &in_grad);

/* Forward pass of LeakyReLU via MKL-DNN: reads in_data and writes the
 * activation result into out_data according to req. Declaration only. */
void MKLDNNLeakyReluForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                            const NDArray &in_data, const OpReqType &req,
                            const NDArray &out_data);
Expand All @@ -133,6 +128,12 @@ void MKLDNNSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

/* For softmax_output: MKL-DNN accelerated forward pass of the SoftmaxOutput
 * operator. Kept grouped with the other softmax declarations above.
 * \param attrs    node attributes carrying the parsed SoftmaxOutputParam
 * \param ctx      operator execution context
 * \param in_data  input NDArrays (data, label)
 * \param req      write-request type per output
 * \param out_data output NDArrays receiving the softmax result */
void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray> &in_data,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &out_data);

/* Element-wise sum of two mkldnn::memory buffers, result written to out.
 * Declaration only; presumably arr1/arr2/out share compatible descriptors --
 * confirm against the implementation. */
void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
               const mkldnn::memory &out);

Expand Down
39 changes: 8 additions & 31 deletions src/operator/nn/mkldnn/mkldnn_softmax_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,17 @@
* \author Zhang Rong A
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "../../softmax_output-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

static mkldnn::softmax_forward::primitive_desc GetSoftmaxOutputFwdDescImpl(
const SoftmaxOutputParam& param, bool is_train,
const int axis, const mkldnn::memory &input_mem) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc data_md = input_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto prop = is_train ? mkldnn::prop_kind::forward_training
: mkldnn::prop_kind::forward_scoring;
Expand All @@ -47,38 +45,17 @@ typedef ParamOpSign<SoftmaxOutputParam> MKLDNNSoftmaxOuputSignature;

class MKLDNNSoftmaxOutputFwd {
std::shared_ptr<mkldnn::softmax_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;

public:
const mkldnn::softmax_forward::primitive_desc fwd_pd;

MKLDNNSoftmaxOutputFwd(const SoftmaxOutputParam& param, bool is_train,
const int axis, const mkldnn::memory &mem): fwd_pd(
GetSoftmaxOutputFwdDescImpl(param, is_train, axis, mem)) {
fwd_ = std::make_shared<mkldnn::softmax_forward>(fwd_pd);
}

void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
if (this->data_ == nullptr)
this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
else
this->data_->set_data_handle(data.get_data_handle());

if (this->out_ == nullptr)
this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
output.get_primitive_desc(), output.get_data_handle()));
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr) {
this->fwd_ = std::shared_ptr<mkldnn::softmax_forward>(
new mkldnn::softmax_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
*this->out_));
}
}

const mkldnn::softmax_forward &GetFwd() const {
const inline mkldnn::softmax_forward &GetFwd() const {
return *fwd_;
}
};
Expand Down Expand Up @@ -129,17 +106,17 @@ void MKLDNNSoftmaxOutputForward(const nnvm::NodeAttrs& attrs,

auto input_mem = idata.GetMKLDNNData();
auto out_mem = CreateMKLDNNMem(out_data[softmaxout_enum::kOut],
input_mem->get_primitive_desc(), req[softmaxout_enum::kOut]);
input_mem->get_desc(), req[softmaxout_enum::kOut]);

MKLDNNSoftmaxOutputFwd &fwd = GetSoftmaxOutputForward(param, ctx, idata);
fwd.SetNewMem(*input_mem, *out_mem.second);

MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(fwd.GetFwd());

stream->RegisterPrimArgs(fwd.GetFwd(),
{{MKLDNN_ARG_SRC, *input_mem}, {MKLDNN_ARG_DST, *out_mem.second}});
CommitOutput(out_data[softmaxout_enum::kOut], out_mem);
stream->Submit();
}
} // namespace op
} // namespace mxnet
#endif

7 changes: 4 additions & 3 deletions src/operator/softmax_output.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@
* \author Bing Xu, Zhang Rong A
*/
#include "./softmax_output-inl.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "./nn/mkldnn/mkldnn_ops-inl.h"
#include "./nn/mkldnn/mkldnn_base-inl.h"
#endif
namespace mxnet {
namespace op {
Expand Down Expand Up @@ -121,7 +122,7 @@ static bool SoftmaxOutputShape(const nnvm::NodeAttrs& attrs,
return true;
}

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
inline static bool SoftmaxOutputStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
Expand Down Expand Up @@ -231,7 +232,7 @@ NNVM_REGISTER_OP(SoftmaxOutput)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr_parser(ParamParser<SoftmaxOutputParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxOutputStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxOutputComputeExCPU)
Expand Down