This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-v1.0] Add MKL-DNN FC #16221

Merged · 2 commits · Sep 30, 2019
12 changes: 11 additions & 1 deletion src/imperative/imperative_utils.h
@@ -541,8 +541,18 @@ inline void PushOperator(const OpStatePtr& state,
         // copying A to B may not happen, and will corrupt A's memory.
         InvalidateOutputs(outputs, req);
       }
+      // add for mkldnn OP + no mkldnn OP
+      const auto is_mkldnn = Op::GetAttr<bool>("TIsMKLDNN");
+      if (!is_mkldnn.get(attrs.op, false)) {
+        std::vector<NDArray> inputs_fallback;
+        CreateDefaultInputs(inputs, &inputs_fallback);
+        fcompute_ex(state, opctx, inputs_fallback, req, outputs);
+      } else {
 #endif
-      fcompute_ex(state, opctx, inputs, req, outputs);
+        fcompute_ex(state, opctx, inputs, req, outputs);
+#if MXNET_USE_MKLDNN == 100
+      }
+#endif
       if (ctx.dev_mask() == gpu::kDevMask && exec_type == ExecType::kSync
           && rctx.get_stream<gpu>() && !rctx.is_bulk) {
         rctx.get_stream<gpu>()->Wait();
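The added branch keys off the `TIsMKLDNN` op attribute: operators that don't declare it receive plain-layout copies of the inputs before `fcompute_ex` runs. As a rough illustration of what a helper like `CreateDefaultInputs` is expected to do (the actual implementation lives elsewhere in imperative_utils.h and may differ):

```cpp
// Illustrative sketch only: reorder any MKL-DNN-layout inputs to the
// default (plain) layout so a non-MKL-DNN operator can consume them.
inline void CreateDefaultInputs(const std::vector<NDArray> &arrs,
                                std::vector<NDArray> *out_arrs) {
  out_arrs->clear();
  out_arrs->reserve(arrs.size());
  for (const NDArray &arr : arrs) {
    // Reorder2Default materializes a plain-layout copy only when needed.
    out_arrs->push_back(arr.IsMKLDNNData() ? arr.Reorder2Default() : arr);
  }
}
```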
12 changes: 6 additions & 6 deletions src/operator/nn/fully_connected.cc
@@ -97,7 +97,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
     valid_bias = inputs[2].storage_type() == kDefaultStorage ||
                  inputs[2].storage_type() == kRowSparseStorage;
   }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
       common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
     if (SupportMKLDNNFC(inputs[0])) {
@@ -141,7 +141,7 @@ void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
#endif
}

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext &ctx,
                                    const std::vector<NDArray> &inputs,
@@ -199,7 +199,7 @@ inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif
@@ -233,7 +233,7 @@ inline static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif
@@ -295,7 +295,7 @@ If ``no_bias`` is set to be true, then the ``bias`` term is ignored.
    [](const NodeAttrs& attrs) {
      return std::vector<std::string>{"output"};
    })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
@@ -326,7 +326,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
#endif
15 changes: 5 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -27,7 +27,7 @@
#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100

#include <vector>
#include <string>
@@ -50,7 +50,7 @@ struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
    .describe("Whether to enable float32 output");
    DMLC_DECLARE_FIELD(with_eltwise).set_default(false)
-    .describe("Whether there's a post elemwise after FullyConnected operator");
+    .describe("Whether there's a post with_eltwise after FullyConnected operator");
    DMLC_DECLARE_FIELD(min_calib_range)
    .set_default(dmlc::optional<float>())
    .describe("The minimum scalar value in the form of float32 obtained "
@@ -85,21 +85,16 @@ class MKLDNNFullyConnectedForward {
                               const NDArray &data, const NDArray &weight,
                               const NDArray *bias,
                               const mkldnn::memory::desc &out_md)
-      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}
-
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output);
+      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {
+    fwd_ = std::make_shared<mkldnn::inner_product_forward>(fwd_pd);
+  }

  const mkldnn::inner_product_forward &GetFwd() const {
    return *fwd_;
  }

 private:
  std::shared_ptr<mkldnn::inner_product_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> weight_;
-  std::shared_ptr<mkldnn::memory> bias_;
-  std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
128 changes: 45 additions & 83 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -24,7 +24,7 @@
 * \author Da Zheng, Ciyong Chen
 */

-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
#include "mkldnn_fully_connected-inl.h"

namespace mxnet {
@@ -67,7 +67,6 @@ mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
      }

      attr.set_output_scales(mask, scales);
-      attr.set_int_output_round_mode(round_nearest);
    }
  }

@@ -130,51 +129,6 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
  }
}

-void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
-                                            const mkldnn::memory &weight,
-                                            const mkldnn::memory *bias,
-                                            const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.src_primitive_desc(), data.get_data_handle()));
-  else
-    this->data_->set_data_handle(data.get_data_handle());
-
-  if (this->weight_ == nullptr)
-    this->weight_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-  else
-    this->weight_->set_data_handle(weight.get_data_handle());
-
-  if (this->out_ == nullptr)
-    this->out_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        fwd_pd.dst_primitive_desc(), output.get_data_handle()));
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (bias != nullptr) {
-    if (this->bias_ == nullptr)
-      this->bias_ = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-    else
-      this->bias_->set_data_handle(bias->get_data_handle());
-
-    if (this->fwd_ == nullptr)
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_),
-              mkldnn::primitive::at(*this->bias_), *this->out_));
-  } else {
-    if (this->fwd_ == nullptr) {
-      this->fwd_ = std::shared_ptr<mkldnn::inner_product_forward>(
-          new mkldnn::inner_product_forward(
-              fwd_pd, mkldnn::primitive::at(*this->data_),
-              mkldnn::primitive::at(*this->weight_), *this->out_));
-    }
-  }
-}
-
MKLDNNFullyConnectedForward &GetFCFwd(
    const FullyConnectedParam &param, const bool is_train,
    const NDArray &data, const NDArray &weight,
@@ -223,13 +177,13 @@ void MKLDNNFCFlattenData(const FullyConnectedParam &param,
      mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
        static_cast<int>(oshape[ishape.ndim()-1])};
      *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-                                    mkldnn::memory::format::any);
+                                    mkldnn::memory::format_tag::any);
    } else {
      *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
      mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
        static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
      *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
-                                    mkldnn::memory::format::any);
+                                    mkldnn::memory::format_tag::any);
    }
  }
}
@@ -244,35 +198,35 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
  NDArray weight = in_data[fullc::kWeight];
  NDArray data = in_data[fullc::kData];

-  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_primitive_desc());
+  auto data_mem = data.GetMKLDNNDataReorder(fwd->fwd_pd.src_desc());
  const mkldnn::memory *weight_mem;
  if (ctx.is_train) {
    if (weight.IsMKLDNNData()) {
      weight.Reorder2DefaultAsync();
    }
-    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+    weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
  } else {
-    if (weight.IsDefaultData()) {
-      // We also need to modify the layout on the original weight array.
-      // Don't switch below sequence because naive engine will executes
-      // pushAsync synchronously.
-      weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
-    } else {
-      weight_mem = weight.GetMKLDNNData();
-      CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
+    weight_mem = weight.GetMKLDNNData();
+    if (weight_mem->get_desc() != fwd->fwd_pd.weights_desc()) {
+      // TODO(rongzha1): rm following line for ut:test_contrib_rnn, need debug
+      // weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_desc(), 1);
    }
  }
  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      fwd->fwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+      fwd->fwd_pd.dst_desc(), req[fullc::kOut], &data);

+  std::unordered_map<int, mkldnn::memory> args = {
+      {MKLDNN_ARG_SRC, *data_mem},
+      {MKLDNN_ARG_WEIGHTS, *weight_mem},
+      {MKLDNN_ARG_DST, *out_mem.second},
+  };
  if (!full_param.default_param.no_bias) {
    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
-        fwd->fwd_pd.bias_primitive_desc());
-    fwd->SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  } else {
-    fwd->SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
+        fwd->fwd_pd.bias_desc());
+    args.insert({ MKLDNN_ARG_BIAS, *bias_mem});
  }
-  MKLDNNStream::Get()->RegisterPrim(fwd->GetFwd());
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd->GetFwd(), args);
  CommitOutput(out_data[fullc::kOut], out_mem);
  MKLDNNStream::Get()->Submit();
}
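The rewrite above tracks the central API change between MKL-DNN 0.x and 1.0: a primitive no longer binds its memory objects at construction time (which is why `SetNewMem` and the cached memory members disappear), and instead receives them at execution time through an argument map. A minimal, self-contained sketch of that 1.0 pattern, with illustrative shapes and no bias term; it mirrors `RegisterPrimArgs` conceptually but is not the PR's code:

```cpp
#include <unordered_map>
#include "mkldnn.hpp"

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  const int N = 2, IC = 8, OC = 4;  // batch, input features, output features
  mkldnn::memory::desc src_md({N, IC}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nc);
  mkldnn::memory::desc wgt_md({OC, IC}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::oi);
  mkldnn::memory::desc dst_md({N, OC}, mkldnn::memory::data_type::f32,
                              mkldnn::memory::format_tag::nc);

  // Build the primitive once from its primitive_desc, as the new
  // MKLDNNFullyConnectedForward constructor now does.
  mkldnn::inner_product_forward::desc desc(
      mkldnn::prop_kind::forward_inference, src_md, wgt_md, dst_md);
  mkldnn::inner_product_forward::primitive_desc pd(desc, eng);
  mkldnn::inner_product_forward fc(pd);

  // Memory objects own engine-allocated f32 buffers (zero-initialized
  // contents are fine for the sketch).
  mkldnn::memory src_mem(pd.src_desc(), eng);
  mkldnn::memory wgt_mem(pd.weights_desc(), eng);
  mkldnn::memory dst_mem(pd.dst_desc(), eng);

  // 1.0-style dispatch: memory is supplied per execution via an args map.
  std::unordered_map<int, mkldnn::memory> args = {
      {MKLDNN_ARG_SRC, src_mem},
      {MKLDNN_ARG_WEIGHTS, wgt_mem},
      {MKLDNN_ARG_DST, dst_mem},
  };
  fc.execute(strm, args);
  strm.wait();
  return 0;
}
```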
Expand Down Expand Up @@ -339,37 +293,45 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetFCBwdData(
data, weight, out_grad, fwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
ipBwdData_pd.diff_dst_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_primitive_desc(),
ipBwdData_pd.diff_src_desc(),
req[fullc::kData]);
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
std::unordered_map<int, mkldnn::memory> args = {
{MKLDNN_ARG_DIFF_DST, *out_grad_mem},
{MKLDNN_ARG_WEIGHTS, *weight_mem},
{MKLDNN_ARG_DIFF_SRC, *in_grad_mem.second}
};

MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::inner_product_backward_data(ipBwdData_pd), args);
[Inline review comment, Contributor]
Better to cache backward_data/backward_weights (create them only once) instead of creating them every time this is called?

[Reply, Contributor Author]
The FC backward feature was not enabled as of MKL-DNN 0.2. Maybe leave that until this feature is enabled?
    CommitOutput(in_grad[fullc::kData], in_grad_mem);
  }
  if (req[fullc::kWeight]) {
    mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
        = GetFCBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
                          out_grad, fwd_pd);
    auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
+        ipBwdWeights_pd.diff_dst_desc());
+    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_desc());
    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
-                                                 ipBwdWeights_pd.diff_weights_primitive_desc(),
+                                                 ipBwdWeights_pd.diff_weights_desc(),
                                                 req[fullc::kWeight]);
+    std::unordered_map<int, mkldnn::memory> args = {
+        {MKLDNN_ARG_DIFF_DST, *out_grad_mem},
+        {MKLDNN_ARG_SRC, *data_mem},
+        {MKLDNN_ARG_DIFF_WEIGHTS, *in_grad_weight.second},
+    };
+
    mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
-    } else {
+    if (!param.no_bias) {
      in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
-                                     ipBwdWeights_pd.diff_bias_primitive_desc(),
+                                     ipBwdWeights_pd.diff_bias_desc(),
                                     req[fullc::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-          *in_grad_bias.second));
+      args.insert({MKLDNN_ARG_DIFF_BIAS, *in_grad_bias.second});
    }
+    MKLDNNStream::Get()->RegisterPrimArgs(
+        mkldnn::inner_product_backward_weights(ipBwdWeights_pd), args);
    CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
    CommitOutput(in_grad[fullc::kBias], in_grad_bias);
  }
@@ -378,4 +340,4 @@

}  // namespace op
}  // namespace mxnet
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
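On the caching question from the review thread: one way to cache the backward primitives would mirror what this file already does for the forward pass (`GetFCFwd`, a thread-local map keyed by `MKLDNNFullyconSignature`). A hypothetical sketch under that assumption; `MKLDNNFCBwdWeights` and `GetFCBwdWeightsCached` are invented names, not part of the PR:

```cpp
// Hypothetical sketch: cache the backward-weights primitive the same way
// GetFCFwd caches the forward one. MKLDNNFullyconSignature (ParamOpSign)
// and OpHash are existing MXNet helpers; the class and function here are
// invented for illustration.
class MKLDNNFCBwdWeights {
 public:
  mkldnn::inner_product_backward_weights::primitive_desc bwd_pd;

  explicit MKLDNNFCBwdWeights(
      const mkldnn::inner_product_backward_weights::primitive_desc &pd)
      : bwd_pd(pd),
        bwd_(std::make_shared<mkldnn::inner_product_backward_weights>(pd)) {}

  const mkldnn::inner_product_backward_weights &GetBwd() const { return *bwd_; }

 private:
  std::shared_ptr<mkldnn::inner_product_backward_weights> bwd_;
};

static MKLDNNFCBwdWeights &GetFCBwdWeightsCached(
    const FullyConnectedParam &param, const NDArray &data,
    const NDArray &weight, const NDArray *bias, const NDArray &out_grad,
    const mkldnn::inner_product_forward::primitive_desc &fwd_pd) {
  // One cache per thread, keyed by op parameters plus array signatures.
  static thread_local std::unordered_map<MKLDNNFullyconSignature,
                                         MKLDNNFCBwdWeights, OpHash> bwds;
  MKLDNNFullyconSignature key(param);
  key.AddSign(data);
  key.AddSign(weight);
  key.AddSign(out_grad);
  if (bias != nullptr) key.AddSign(*bias);

  auto it = bwds.find(key);
  if (it == bwds.end()) {
    MKLDNNFCBwdWeights bwd(GetFCBwdWeights(data, weight, bias, out_grad, fwd_pd));
    it = bwds.emplace(key, bwd).first;
  }
  return it->second;
}
```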
20 changes: 10 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -44,16 +44,6 @@ namespace mxnet {
namespace op {

#if MXNET_USE_MKLDNN == 1
-/* For fully connected. */
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                     const std::vector<NDArray> &in_data,
-                     const std::vector<OpReqType> &req,
-                     const std::vector<NDArray> &out_data);
-void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                      const std::vector<NDArray> &inputs,
-                      const std::vector<OpReqType> &req,
-                      const std::vector<NDArray> &outputs);
-
/* For deconvolution */
void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
@@ -104,6 +94,16 @@ void MKLDNNFlattenForward(const nnvm::NodeAttrs &attrs,
#endif

#if MXNET_USE_MKLDNN == 100
+/* For fully connected. */
+void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                     const std::vector<NDArray> &in_data,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<NDArray> &out_data);
+void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                      const std::vector<NDArray> &inputs,
+                      const std::vector<OpReqType> &req,
+                      const std::vector<NDArray> &outputs);
+
/* For convolution. */
void MKLDNNConvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,