Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[mkldnn-v1.0] Add MKL-DNN activation (#16195)
Browse files Browse the repository at this point in the history
* add mkldnn act; pass lint; pass mnist training

* make bwd as private member
  • Loading branch information
rongzha1 authored and TaoLv committed Sep 20, 2019
1 parent 99145a5 commit f930baa
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 128 deletions.
18 changes: 9 additions & 9 deletions src/operator/nn/activation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@
#include "./activation-inl.h"
#include "../mshadow_op.h"
#include "../tensor/elemwise_unary_op.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "./mkldnn/mkldnn_base-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#include "../operator_common.h"
#include "../../common/utils.h"

Expand Down Expand Up @@ -91,7 +91,7 @@ struct ActivationGrad {
}
};

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
Expand Down Expand Up @@ -150,7 +150,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100


MXNET_OPERATOR_REGISTER_UNARY(Activation)
Expand All @@ -167,15 +167,15 @@ The following activation functions are supported:
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<ActivationParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", ActivationStorageType)
#endif
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"output"};
})
.set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
#endif
Expand All @@ -189,21 +189,21 @@ NNVM_REGISTER_OP(_backward_Activation)
})
.set_num_outputs(1)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
#endif
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<-1, 1>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr_parser(ParamParser<ActivationParam>)
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
#endif
Expand Down
40 changes: 28 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_act-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,43 +20,39 @@
/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_act-inl.h
* \brief MKLDNN(Quantized) Activation operator based on subgraph
* \brief MKLDNN Activation operator
* /author Zhiyuan Huang
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_


#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include <vector>
#include <utility>
#include "../activation-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
const ActivationParam& param, bool is_train,
const mkldnn::memory &input_mem, int dtype);
const ActivationParam& param, bool is_train, const mkldnn::memory &input_mem);

class MKLDNNActForward {
public:
const mkldnn::eltwise_forward::primitive_desc fwd_pd;

MKLDNNActForward(const ActivationParam& param, bool is_train,
const NDArray &data, const mkldnn::memory &mem): fwd_pd(
GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
const mkldnn::eltwise_forward &GetFwd() const;
GetActFwdDescImpl(param, is_train, mem)) {
fwd_ = std::make_shared<mkldnn::eltwise_forward>(fwd_pd);
}
const inline mkldnn::eltwise_forward &GetFwd() const;

private:
std::shared_ptr<mkldnn::eltwise_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
Expand All @@ -67,8 +63,28 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);

mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const ActivationParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory);

class MKLDNNActBackward {
public:
const mkldnn::eltwise_backward::primitive_desc pd;

explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
const mkldnn::memory &mem,
const mkldnn::memory &diff_dst_memory): pd(
GetActBwdDescImpl(param, mem, diff_dst_memory)) {
bwd = std::make_shared<mkldnn::eltwise_backward>(pd);
}
const inline mkldnn::eltwise_backward &GetBwd() const;

private:
std::shared_ptr<mkldnn::eltwise_backward> bwd;
};
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
130 changes: 31 additions & 99 deletions src/operator/nn/mkldnn/mkldnn_act.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
* \author Da Zheng
*/

#if MXNET_USE_MKLDNN == 100

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
Expand All @@ -33,10 +35,7 @@
#include <utility>
#include "../../operator_common.h"
#include "mkldnn_act-inl.h"

#if MXNET_USE_MKLDNN == 1

#include <mkldnn.hpp>
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {
Expand Down Expand Up @@ -81,41 +80,19 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {

mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
const ActivationParam& param, bool is_train,
const mkldnn::memory &input_mem, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
auto cpu_engine = data_mpd.get_engine();

const mkldnn::memory &input_mem) {
mkldnn::memory::desc data_md = input_mem.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto alg = GetMKLDNNActAlgo(param);

auto prop = is_train ? mkldnn::prop_kind::forward_training :
mkldnn::prop_kind::forward_scoring;
auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
}

void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
if (this->data_ == nullptr)
this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
data.get_data_handle());
else
this->data_->set_data_handle(data.get_data_handle());

CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
if (this->out_ == nullptr)
this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
output.get_data_handle());
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr) {
this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
*this->out_));
}
return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
}

const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
const inline mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
return *fwd_;
}

Expand All @@ -131,7 +108,6 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
key.AddSign(ctx.is_train);
key.AddSign(param.act_type);
key.AddSign(in_data);

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
Expand All @@ -153,81 +129,34 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

auto input_mem = in_buffer.GetMKLDNNData();
MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
fwd.SetNewMem(*input_mem, *out_mem_t.second);
stream->RegisterPrim(fwd.GetFwd());
auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
stream->RegisterPrimArgs(fwd.GetFwd(),
{{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}});
CommitOutput(out_data, out_mem_t);
stream->Submit();
}

static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
const ActivationParam &param, const mkldnn::memory &input_mem,
const mkldnn::memory &diff_dst_memory, int dtype) {
mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
mkldnn::memory::desc data_md = data_mpd.desc();
mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
auto cpu_engine = data_mpd.get_engine();
const mkldnn::memory &diff_dst_memory) {
mkldnn::memory::desc data_md = input_mem.get_desc();
mkldnn::memory::desc diff_md = diff_dst_memory.get_desc();
auto cpu_engine = CpuEngine::Get()->get_engine();
auto alg = GetMKLDNNActAlgo(param);

MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
DType alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
return bw_pdesc;
});
LOG(FATAL) << "Unsupported data type for MKLDNN activation";
float alpha = 0;
mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
alg, data_md, 0.0);
alg, data_md, alpha);
mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
fw_pdesc);
return bw_pdesc;
}

class MKLDNNActBackward {
std::shared_ptr<mkldnn::eltwise_backward> bwd;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> diff_dst_memory;
std::shared_ptr<mkldnn::memory> diff_src_memory;

public:
const mkldnn::eltwise_backward::primitive_desc pd;

explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
const mkldnn::memory &mem,
const mkldnn::memory &diff_dst_memory)
: pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}

void SetNewMem(const mkldnn::memory &data,
const mkldnn::memory &diff_dst_memory,
const mkldnn::memory &diff_src_memory) {
if (this->bwd != nullptr) {
this->data->set_data_handle(data.get_data_handle());
this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
} else {
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
data.get_primitive_desc(), data.get_data_handle()));
this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
diff_dst_memory.get_data_handle()));
this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(diff_src_memory.get_primitive_desc(),
diff_src_memory.get_data_handle()));
this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
new mkldnn::eltwise_backward(
this->pd, mkldnn::primitive::at(*this->data),
*this->diff_dst_memory, *this->diff_src_memory));
}
}

const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
};
const inline mkldnn::eltwise_backward &MKLDNNActBackward::GetBwd() const {
return *bwd;
}

static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
const OpContext &ctx,
Expand Down Expand Up @@ -274,20 +203,23 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
auto input_mem = in_buffer.GetMKLDNNData();
// We need to make sure the two inputs to eltwise_backward has the same memory
// descriptor. Otherwise, the perf will suffer.
if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
if (input_mem->get_desc() != diff_dst_memory->get_desc())
input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
MKLDNNActBackward &bwd =
GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
MKLDNNStream *stream = MKLDNNStream::Get();
mkldnn_output_t diff_src_memory =
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
stream->RegisterPrim(bwd.GetBwd());
CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
mkldnn_args_map_t args = {
{ MKLDNN_ARG_SRC, *input_mem },
{ MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
{ MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
};
stream->RegisterPrimArgs(bwd.GetBwd(), args);
CommitOutput(in_grad, diff_src_memory);
stream->Submit();
}

} // namespace op
} // namespace mxnet

#endif
15 changes: 7 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,14 +95,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

/* For activation */
void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &out_grad, const NDArray &in_data,
const OpReqType &req, const NDArray &in_grad);

void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
const OpContext &ctx,
const NDArray &data,
Expand Down Expand Up @@ -133,6 +125,13 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs);

/* For activation */
void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &in_data, const OpReqType &req,
const NDArray &out_data);
void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const NDArray &out_grad, const NDArray &in_data,
const OpReqType &req, const NDArray &in_grad);

void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
const mkldnn::memory &out);
Expand Down

0 comments on commit f930baa

Please sign in to comment.