
[mkldnn-v1.0] Add MKL-DNN activation #16195

Merged · 2 commits · Sep 20, 2019
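This PR ports the MKL-DNN activation operator to the MKL-DNN v1.0 API: the version guards move from `MXNET_USE_MKLDNN == 1` to `== 100`, `memory::primitive_desc` gives way to plain `memory::desc`, the engine comes from `CpuEngine`, primitives are built once in the wrapper constructors instead of being re-bound through `SetNewMem`, and execution goes through `RegisterPrimArgs` with an explicit argument map. For orientation, here is a minimal standalone sketch of the v1.0 eltwise flow the diff adopts; the tensor shape, ReLU algorithm, and alpha value are illustrative assumptions, not code from this PR:

```cpp
// Minimal sketch of the MKL-DNN v1.0 eltwise (activation) flow this PR
// migrates to. Assumptions: 2x8 f32 tensor, ReLU, alpha = 0 -- all
// illustrative, not taken from the PR.
#include <mkldnn.hpp>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream s(eng);

  // v1.0 describes memory with memory::desc alone (no primitive_desc).
  mkldnn::memory::desc md({2, 8}, mkldnn::memory::data_type::f32,
                          mkldnn::memory::format_tag::nc);
  std::vector<float> src_buf(16, -1.f), dst_buf(16, 0.f);
  mkldnn::memory src(md, eng, src_buf.data());
  mkldnn::memory dst(md, eng, dst_buf.data());

  // desc -> primitive_desc -> primitive; no memory is bound at creation.
  mkldnn::eltwise_forward::desc d(mkldnn::prop_kind::forward_inference,
                                  mkldnn::algorithm::eltwise_relu, md, 0.f);
  mkldnn::eltwise_forward::primitive_desc pd(d, eng);
  mkldnn::eltwise_forward relu(pd);

  // Arguments are supplied per execution as a map -- the pattern that
  // MKLDNNStream::RegisterPrimArgs wraps in the diffs below.
  relu.execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  s.wait();
  return 0;
}
```

Because v1.0 primitives receive their memory only at execution time, one cached primitive can be reused across calls whose buffers share a descriptor, which is what makes the old `SetNewMem` plumbing removable.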
18 changes: 9 additions & 9 deletions src/operator/nn/activation.cc
@@ -27,10 +27,10 @@
 #include "./activation-inl.h"
 #include "../mshadow_op.h"
 #include "../tensor/elemwise_unary_op.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #include "../operator_common.h"
 #include "../../common/utils.h"

@@ -91,7 +91,7 @@ struct ActivationGrad {
   }
 };

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
@@ -150,7 +150,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
   return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
                            dispatch_mode, in_attrs, out_attrs);
 }
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100


 MXNET_OPERATOR_REGISTER_UNARY(Activation)
@@ -167,15 +167,15 @@ The following activation functions are supported:

 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", ActivationStorageType)
 #endif
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
     [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
 })
 .set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
 #endif
@@ -189,21 +189,21 @@ NNVM_REGISTER_OP(_backward_Activation)
 })
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
 #endif
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
 .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
     return std::vector<std::pair<int, int> >{{0, 0}};
 })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
     return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
 #endif
40 changes: 28 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_act-inl.h
@@ -20,43 +20,39 @@
 /*!
  * Copyright (c) 2019 by Contributors
  * \file mkldnn_act-inl.h
- * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \brief MKLDNN Activation operator
  * \author Zhiyuan Huang
  */

 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_

-
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include <utility>
 #include "../activation-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"

 namespace mxnet {
 namespace op {

 mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
-    const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype);
+    const ActivationParam& param, bool is_train, const mkldnn::memory &input_mem);

 class MKLDNNActForward {
  public:
   const mkldnn::eltwise_forward::primitive_desc fwd_pd;

   MKLDNNActForward(const ActivationParam& param, bool is_train,
                    const NDArray &data, const mkldnn::memory &mem): fwd_pd(
-                       GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
-  const mkldnn::eltwise_forward &GetFwd() const;
+                       GetActFwdDescImpl(param, is_train, mem)) {
+    fwd_ = std::make_shared<mkldnn::eltwise_forward>(fwd_pd);
+  }
+  const inline mkldnn::eltwise_forward &GetFwd() const;

  private:
   std::shared_ptr<mkldnn::eltwise_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
 };

 typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
@@ -67,8 +63,28 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
 void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &in_data, const OpReqType &req,
                              const NDArray &out_data);
+
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory);
+
+class MKLDNNActBackward {
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory): pd(
+                                 GetActBwdDescImpl(param, mem, diff_dst_memory)) {
+    bwd = std::make_shared<mkldnn::eltwise_backward>(pd);
+  }
+  const inline mkldnn::eltwise_backward &GetBwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+};
 }  // namespace op
 }  // namespace mxnet

-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
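With the declarations above, both wrapper classes now create their primitive eagerly in the constructor and cache only the primitive plus its `primitive_desc`; the per-call memory binding that `SetNewMem` used to do is gone. A hedged sketch of the caching pattern this header encodes follows; `EltwiseFwd` and its members are stand-ins for illustration, not the actual MXNet wrappers:

```cpp
// Sketch of the caching pattern the new header encodes: build the
// primitive once from its primitive_desc, keep it in a shared_ptr, and
// feed memory per call. EltwiseFwd is a stand-in, not the MXNet class.
#include <mkldnn.hpp>
#include <memory>

class EltwiseFwd {
 public:
  const mkldnn::eltwise_forward::primitive_desc fwd_pd;

  EltwiseFwd(const mkldnn::memory::desc &md, const mkldnn::engine &eng)
      : fwd_pd(mkldnn::eltwise_forward::desc(
                   mkldnn::prop_kind::forward_training,
                   mkldnn::algorithm::eltwise_relu, md, 0.f),
               eng) {
    // Mirrors the constructors above: the primitive exists as soon as
    // the wrapper does, so there is no SetNewMem step anymore.
    fwd_ = std::make_shared<mkldnn::eltwise_forward>(fwd_pd);
  }

  // One cached primitive serves any buffers whose descriptors match
  // fwd_pd; the memory objects arrive only at execution time.
  void Run(mkldnn::stream &s, const mkldnn::memory &src,
           const mkldnn::memory &dst) const {
    fwd_->execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
  }

 private:
  std::shared_ptr<mkldnn::eltwise_forward> fwd_;
};
```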
130 changes: 31 additions & 99 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -23,6 +23,8 @@
  * \author Da Zheng
  */

+#if MXNET_USE_MKLDNN == 100
+
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
@@ -33,10 +35,7 @@
 #include <utility>
 #include "../../operator_common.h"
 #include "mkldnn_act-inl.h"
-
-#if MXNET_USE_MKLDNN == 1
-
-#include <mkldnn.hpp>
+#include "./mkldnn_base-inl.h"

 namespace mxnet {
 namespace op {
@@ -81,41 +80,19 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {

 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
     const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+    const mkldnn::memory &input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);

   auto prop = is_train ? mkldnn::prop_kind::forward_training :
                          mkldnn::prop_kind::forward_scoring;
   auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
-  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
-}
-
-void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
-                                                   data.get_data_handle());
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
-  if (this->out_ == nullptr)
-    this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
-                                                  output.get_data_handle());
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (this->fwd_ == nullptr) {
-    this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
-        new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
-                                    *this->out_));
-  }
+  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
 }

-const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
+const inline mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
   return *fwd_;
 }

@@ -131,7 +108,6 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
   key.AddSign(ctx.is_train);
   key.AddSign(param.act_type);
   key.AddSign(in_data);
-
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
@@ -153,81 +129,34 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
-  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
-  fwd.SetNewMem(*input_mem, *out_mem_t.second);
-  stream->RegisterPrim(fwd.GetFwd());
+  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
+  stream->RegisterPrimArgs(fwd.GetFwd(),
+      {{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}});
   CommitOutput(out_data, out_mem_t);
   stream->Submit();
 }

-static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
     const ActivationParam &param, const mkldnn::memory &input_mem,
-    const mkldnn::memory &diff_dst_memory, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
+    const mkldnn::memory &diff_dst_memory) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);

-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-    return bw_pdesc;
-  });
-  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  float alpha = 0;
   mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                        alg, data_md, 0.0);
+                                        alg, data_md, alpha);
   mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                     fw_pdesc);
   return bw_pdesc;
 }

-class MKLDNNActBackward {
-  std::shared_ptr<mkldnn::eltwise_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> diff_dst_memory;
-  std::shared_ptr<mkldnn::memory> diff_src_memory;
-
- public:
-  const mkldnn::eltwise_backward::primitive_desc pd;
-
-  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
-                             const mkldnn::memory &mem,
-                             const mkldnn::memory &diff_dst_memory)
-      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
-
-  void SetNewMem(const mkldnn::memory &data,
-                 const mkldnn::memory &diff_dst_memory,
-                 const mkldnn::memory &diff_src_memory) {
-    if (this->bwd != nullptr) {
-      this->data->set_data_handle(data.get_data_handle());
-      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
-      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
-    } else {
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          data.get_primitive_desc(), data.get_data_handle()));
-      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
-                             diff_dst_memory.get_data_handle()));
-      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
-                             diff_src_memory.get_data_handle()));
-      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
-          new mkldnn::eltwise_backward(
-              this->pd, mkldnn::primitive::at(*this->data),
-              *this->diff_dst_memory, *this->diff_src_memory));
-    }
-  }
-
-  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
-};
+const inline mkldnn::eltwise_backward &MKLDNNActBackward::GetBwd() const {
+  return *bwd;
+}

 static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
                                                 const OpContext &ctx,
@@ -274,20 +203,23 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   auto input_mem = in_buffer.GetMKLDNNData();
   // We need to make sure the two inputs to eltwise_backward have the same memory
   // descriptor. Otherwise, the perf will suffer.
-  if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
-    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
+  if (input_mem->get_desc() != diff_dst_memory->get_desc())
+    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
   MKLDNNActBackward &bwd =
       GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory =
-      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
-  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
-  stream->RegisterPrim(bwd.GetBwd());
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_SRC, *input_mem },
+    { MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
+    { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
+  };
+  stream->RegisterPrimArgs(bwd.GetBwd(), args);
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }

 }  // namespace op
 }  // namespace mxnet

 #endif
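The rewritten backward path keeps the existing caveat, reordering the source activation whenever its descriptor differs from `diff_dst`'s, and replaces `SetNewMem` plus `RegisterPrim` with an explicit argument map. A rough standalone sketch of the equivalent raw v1.0 call sequence follows; the helper signature and the ReLU/alpha choices are illustrative assumptions, not PR code:

```cpp
// Rough sketch of the raw v1.0 calls behind the rewritten backward path:
// reorder src when its descriptor differs from diff_dst's, then run
// eltwise_backward with an explicit argument map. The helper signature
// and ReLU/alpha choices are illustrative assumptions.
#include <mkldnn.hpp>

void eltwise_bwd_sketch(const mkldnn::engine &eng, mkldnn::stream &s,
                        mkldnn::memory src, mkldnn::memory &diff_dst,
                        mkldnn::memory &diff_src) {
  mkldnn::memory::desc data_md = src.get_desc();
  const mkldnn::memory::desc diff_md = diff_dst.get_desc();

  // Matching descriptors keep eltwise_backward on its fast path (the
  // "perf will suffer" comment in the diff); reorder if they differ.
  if (data_md != diff_md) {
    mkldnn::memory reordered(diff_md, eng);
    mkldnn::reorder(src, reordered)
        .execute(s, {{MKLDNN_ARG_FROM, src}, {MKLDNN_ARG_TO, reordered}});
    src = reordered;
    data_md = diff_md;
  }

  // As in GetActBwdDescImpl, the backward primitive_desc is created
  // against a forward primitive_desc that serves as a hint.
  mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
                                        mkldnn::algorithm::eltwise_relu,
                                        data_md, 0.f);
  mkldnn::eltwise_forward::primitive_desc fw_pd(fw_desc, eng);
  mkldnn::eltwise_backward::desc bw_desc(mkldnn::algorithm::eltwise_relu,
                                         diff_md, data_md, 0.f);
  mkldnn::eltwise_backward::primitive_desc bw_pd(bw_desc, eng, fw_pd);

  mkldnn::eltwise_backward(bw_pd).execute(
      s, {{MKLDNN_ARG_SRC, src},
          {MKLDNN_ARG_DIFF_DST, diff_dst},
          {MKLDNN_ARG_DIFF_SRC, diff_src}});
  s.wait();
}
```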
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -95,14 +95,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);

-/* For activation */
-void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                             const NDArray &in_data, const OpReqType &req,
-                             const NDArray &out_data);
-void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                              const NDArray &out_grad, const NDArray &in_data,
-                              const OpReqType &req, const NDArray &in_grad);
-
 void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpContext &ctx,
                             const NDArray &data,
@@ -133,6 +125,13 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs);

+/* For activation */
+void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                             const NDArray &in_data, const OpReqType &req,
+                             const NDArray &out_data);
+void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                              const NDArray &out_grad, const NDArray &in_data,
+                              const OpReqType &req, const NDArray &in_grad);

 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
                const mkldnn::memory &out);