[mkldnn-v1.0] Add MKL-DNN FC (#16221)
* add mkldnn fc; pass lint; pass mnist training

* add TODO info for future debug
rongzha1 authored and pengzhao-intel committed Sep 30, 2019
1 parent a559760 commit 23093e6
Showing 5 changed files with 77 additions and 110 deletions.
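Note on the port: MKL-DNN v1.0 (guarded by MXNET_USE_MKLDNN == 100) drops the v0.x model in which a primitive captured its memory objects at construction. Memory descriptors are now queried via src_desc()/weights_desc()/dst_desc() instead of *_primitive_desc(), memory::format becomes memory::format_tag, and memory is bound at execution time through an argument map. A minimal sketch of that execution model, assuming an existing stream `strm`, ready-made memory objects `src_mem`/`weight_mem`/`dst_mem`, and a forward primitive descriptor `fwd_pd` (illustrative only, not code from this commit):

    #include <unordered_map>
    #include "mkldnn.hpp"

    // v1.0 style: the primitive is stateless and built once from fwd_pd...
    mkldnn::inner_product_forward fwd(fwd_pd);
    // ...and memory objects are supplied per call through an args map.
    std::unordered_map<int, mkldnn::memory> args = {
        {MKLDNN_ARG_SRC, src_mem},
        {MKLDNN_ARG_WEIGHTS, weight_mem},
        {MKLDNN_ARG_DST, dst_mem},
    };
    fwd.execute(strm, args);  // no SetNewMem-style rebinding needed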
12 changes: 11 additions & 1 deletion src/imperative/imperative_utils.h
@@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state,
       // copying A to B may not happen, and will corrupt A's memory.
       InvalidateOutputs(outputs, req);
     }
+    // add for mkldnn OP + no mkldnn OP
+    const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+    if (!is_mkldnn.get(attrs.op, false)) {
+      std::vector<NDArray> inputs_fallback;
+      CreateDefaultInputs(inputs, &inputs_fallback);
+      fcompute_ex(state, opctx, inputs_fallback, req, outputs);
+    } else {
 #endif
+      fcompute_ex(state, opctx, inputs, req, outputs);
+#if MXNET_USE_MKLDNN == 100
+    }
+#endif
-    fcompute_ex(state, opctx, inputs, req, outputs);
     if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
         && rctx.get_stream<gpu>() && !rctx.is_bulk) {
       rctx.get_stream<gpu>()->Wait();
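Note: the new branch above is what lets the v1.0 integration coexist with operators that have not been ported yet. Any operator that does not advertise the TIsMKLDNN attribute gets its inputs converted to the default layout before its FComputeEx runs. A rough sketch of the semantics CreateDefaultInputs is relied on for here (assumed behavior, not the actual implementation; the helper name below is hypothetical):

    // Assumed: reorder any input still held in an MKL-DNN-internal layout
    // back to MXNet's default layout; pass default-layout inputs through.
    inline void CreateDefaultInputsSketch(const std::vector<NDArray> &inputs,
                                          std::vector<NDArray> *outs) {
      outs->reserve(inputs.size());
      for (const NDArray &in : inputs)
        outs->push_back(in.IsMKLDNNData() ? in.Reorder2Default() : in);
    }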
12 changes: 6 additions & 6 deletions src/operator/nn/fully_connected.cc
@@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
     valid_bias = inputs[2].storage_type() == kDefaultStorage ||
                  inputs[2].storage_type() == kRowSparseStorage;
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
       common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
     if (SupportMKLDNNFC(inputs[0])) {
@@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
 #endif
 }
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                     const OpContext &ctx,
                                     const std::vector<NDArray> &inputs,
@@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFComputeEx);
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!MKLDNNEnvSet())
     *dispatch_mode = DispatchMode::kFComputeFallback;
 #endif
@@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
     dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                      dispatch_mode, DispatchMode::kFCompute);
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (!MKLDNNEnvSet())
     *dispatch_mode = DispatchMode::kFComputeFallback;
 #endif
@@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
 [](const NodeAttrs& attrs) {
   return std::vector<std::string>{"output"};
 })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
 })
 .set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
 #endif
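Note: flipping the guards from `MXNET_USE_MKLDNN == 1` to `== 100` is what moves FullyConnected onto the v1.0 build, and the `TIsMKLDNN` attribute registered here is exactly what the PushOperator fallback in imperative_utils.h checks. A sketch of the opt-in pattern for any operator (`MyOp` and its functions are placeholders, not part of this commit):

    NNVM_REGISTER_OP(MyOp)
    .set_attr<FInferStorageType>("FInferStorageType", MyOpStorageType)
    #if MXNET_USE_MKLDNN == 100
    .set_attr<bool>("TIsMKLDNN", true)  // opt in to the v1.0 dispatch path
    .set_attr<FComputeEx>("FComputeEx<cpu>", MyOpComputeExCPU)
    #endif
    ;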
15 changes: 5 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -27,7 +27,7 @@
 #ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 
 #include <vector>
 #include <string>
@@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
   DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
     .describe("Whether to enable float32 output");
   DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
-    .describe("Whether there's a post elemwise after FullyConnected operator");
+    .describe("Whether there's a post with_eltwise after FullyConnected operator");
   DMLC_DECLARE_FIELD(min_calib_range)
     .set_default(dmlc::optional<float>())
     .describe("The minimum scalar value in the form of float32 obtained "
@@ -85,21 +85,16 @@ class MKLDNNFullyConnectedForward {
                               const NDArray &data, const NDArray &weight,
                               const NDArray *bias,
                               const mkldnn::memory::desc &out_md)
-      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {
+    fwd_ = std::make_shared<mkldnn::inner_product_forward>(fwd_pd);
+  }
 
   const mkldnn::inner_product_forward &GetFwd() const {
     return *fwd_;
   }
 
  private:
   std::shared_ptr<mkldnn::inner_product_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
 };
 
 typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
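Note: v1.0 primitives no longer capture memory at construction, so the forward primitive can be built once in the constructor, and the cached data_/weight_/bias_/out_ members (and SetNewMem) go away. Intended usage, sketched under the assumption that the cached helper GetFCFwd from mkldnn_fully_connected.cc takes the remaining parameters shown truncated there:

    MKLDNNFullyConnectedForward &fwd = GetFCFwd(param, ctx.is_train, data, weight,
        param.no_bias ? nullptr : &in_data[fullc::kBias], out_md);
    // fwd.GetFwd() is reusable across calls; per-call memory is supplied
    // via MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), args);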
128 changes: 45 additions & 83 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -24,7 +24,7 @@
  * \author Da Zheng, Ciyong Chen
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "mkldnn_fully_connected-inl.h"
 
 namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
       }
 
       attr.set_output_scales(mask, scales);
-      attr.set_int_output_round_mode(round_nearest);
     }
   }
 
@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
 }
 }
 
-void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
-                                            const mkldnn::memory &weight,
-                                            const mkldnn::memory *bias,
-                                            const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.src_primitive_desc(), data.get_data_handle()));
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  if (this->weight_ == nullptr)
-    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-  else
-    this->weight_->set_data_handle(weight.get_data_handle());
-
-  if (this->out_ == nullptr)
-    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (bias != nullptr) {
-    if (this->bias_ == nullptr)
-      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-    else
-      this->bias_->set_data_handle(bias->get_data_handle());
-
-    if (this->fwd_ == nullptr)
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_),
-              mkldnn::primitive::at(*this->bias_), *this->out_));
-  } else {
-    if (this->fwd_ == nullptr) {
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_), *this->out_));
-    }
-  }
-}
-
 MKLDNNFullyConnectedForward &GetFCFwd(
     const FullyConnectedParam &param, const bool is_train,
     const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam &param,
     mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
                                   static_cast<int>(oshape[ishape.ndim()-1])};
     *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-                                   mkldnn::memory::format::any);
+                                   mkldnn::memory::format_tag::any);
   } else {
     *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
     mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
                                   static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
     *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-                                   mkldnn::memory::format::any);
+                                   mkldnn::memory::format_tag::any);
     }
   }
 }
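Note: `memory::format::any` became `memory::format_tag::any` in v1.0; passing `any` leaves the concrete layout for MKL-DNN to choose when the primitive descriptor is created. A sketch of the 2-D output descriptor this helper builds, with illustrative values (batch 32, 10 hidden units) that are not from this commit:

    mkldnn::memory::dims out_dims{32, 10};
    auto out_md = mkldnn::memory::desc(out_dims, mkldnn::memory::data_type::f32,
                                       mkldnn::memory::format_tag::any);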
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
   NDArray weight = in_data[fullc::kWeight];
   NDArray data = in_data[fullc::kData];
 
-  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     if (weight.IsMKLDNNData()) {
       weight.Reorder2DefaultAsync();
     }
-    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
   } else {
-    if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array.
-      // Don't switch below sequence because naive engine will executes
-      // pushAsync synchronously.
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
-    } else {
-      weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+    weight_mem = weight.GetMKLDNNData();
+    if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
+      // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
+      // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
     }
   }
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+      fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);
 
+  std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DST, *out_mem.second},
+  };
   if (!full_param.default_param.no_bias) {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
-        fwd->fwd_pd.bias_primitive_desc());
-    fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  } else {
-    fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+        fwd->fwd_pd.bias_desc());
+    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
   }
-  MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
Expand Down Expand Up @@ -339,37 +293,45 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
         data, weight, out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_primitive_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
+        ipBwdData_pd.diff_dst_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
     auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_primitive_desc(),
+                                       ipBwdData_pd.diff_src_desc(),
                                        req[fullc::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
-        ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    std::unordered_map<int, mkldnn::memory> args = {
+        {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+        {MKLDNN_ARG_WEIGHTS, *weight_mem},
+        {MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
+    };
+
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
     CommitOutput(in_grad[fullc::kData], in_grad_mem);
   }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
         = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
                           out_grad, fwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
+        ipBwdWeights_pd.diff_dst_desc());
+    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
     auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
-                                                 ipBwdWeights_pd.diff_weights_primitive_desc(),
+                                                 ipBwdWeights_pd.diff_weights_desc(),
                                                  req[fullc::kWeight]);
+    std::unordered_map<int, mkldnn::memory> args = {
+        {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+        {MKLDNN_ARG_SRC, *data_mem},
+        {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
+    };
+
     mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
-    } else {
+    if (!param.no_bias) {
       in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
-                                     ipBwdWeights_pd.diff_bias_primitive_desc(),
+                                     ipBwdWeights_pd.diff_bias_desc(),
                                      req[fullc::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-          *in_grad_bias.second));
+      args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
     }
+    MKLDNNStream::Get()->RegisterPrimArgs(
+        mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
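Note: because v1.0 primitives take no memory at construction, the bias and no-bias cases above no longer need two separate primitive constructions. The bias gradient is just an extra MKLDNN_ARG_DIFF_BIAS entry in the shared args map, and a single inner_product_backward_weights primitive is registered either way.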
@@ -378,4 +340,4 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,

 }  // namespace op
 }  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
20 changes: 10 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -44,16 +44,6 @@ namespace mxnet {
 namespace op {
 
 #if MXNET_USE_MKLDNN == 1
-/* For fully connected. */
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                     const std::vector<NDArray> &in_data,
-                     const std::vector<OpReqType> &req,
-                     const std::vector<NDArray> &out_data);
-void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                      const std::vector<NDArray> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<NDArray> &outputs);
-
 /* For deconvolution */
 void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                 const std::vector<NDArray> &in_data,
@@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
 #endif
 
 #if MXNET_USE_MKLDNN == 100
+/* For fully connected. */
+void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                     const std::vector<NDArray> &in_data,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<NDArray> &out_data);
+void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                      const std::vector<NDArray> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<NDArray> &outputs);
+
 /* For convolution. */
 void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                               const std::vector<NDArray> &in_data,
