Commit (fork of apache/mxnet)
* Rebase to Da Zheng's refactor branch (Jan. 14); add a signature for the MKLDNN deconvolution forward primitive and modify class MKLDNNDeconvForward
* Fix `make lint` complaints
Showing 2 changed files with 176 additions and 20 deletions.
Diff: mkldnn_deconvolution.cc
```diff
@@ -20,14 +20,15 @@
 /*!
  * \file mkldnn_deconvolution.cc
  * \brief
- * \author Da Zheng
+ * \author Da Zheng, Rong Zhang ([email protected])
 */

+#if MXNET_USE_MKLDNN == 1
+
 #include "../deconvolution-inl.h"
 #include "./mkldnn_ops-inl.h"
 #include "./mkldnn_base-inl.h"

-#if MXNET_USE_MKLDNN == 1
 namespace mxnet {
 namespace op {

```
```diff
@@ -59,7 +60,7 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_(
   }
 }

-static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
+static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
     const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
     bool has_bias, const NDArray &output) {
   auto data_md = GetMemDesc(data);
```
```diff
@@ -70,11 +71,21 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
   if (param.stride.ndim() == 2) {
     strides[0] = param.stride[0];
     strides[1] = param.stride[1];
+  } else if (param.stride.ndim() == 1) {
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[0];
+  } else {
+    LOG(FATAL) << "Unsupported stride dim";
   }
   mkldnn::memory::dims padding{0, 0};
   if (param.pad.ndim() == 2) {
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.pad.ndim() == 1) {
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[0];
+  } else {
+    LOG(FATAL) << "Unsupported pad dim";
   }
   mkldnn::memory::dims dilate{0, 0};
   if (param.dilate.ndim() == 2) {
```
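The same 1-D-to-2-D normalization is repeated verbatim for `stride` and `pad` in this hunk and again in the two hunks below (GetDeconvBwdData and GetDeconvBwdWeights). A possible follow-up cleanup, not part of this commit, would factor it into a helper. A minimal sketch, assuming the file's existing includes and MXNet's `TShape`; the name `To2DDims` is hypothetical:

```cpp
// Hypothetical helper (not in this patch): pass a 2-element shape through,
// broadcast a 1-element shape to both spatial dimensions, and fail otherwise --
// exactly the branching the diff above repeats three times.
static mkldnn::memory::dims To2DDims(const mxnet::TShape &v, const char *name) {
  mkldnn::memory::dims dims{0, 0};
  if (v.ndim() == 2) {
    dims[0] = v[0];
    dims[1] = v[1];
  } else if (v.ndim() == 1) {
    dims[0] = v[0];
    dims[1] = v[0];  // a single value applies to both dimensions
  } else {
    LOG(FATAL) << "Unsupported " << name << " dim";
  }
  return dims;
}
```

With such a helper, each of the three functions would reduce to `strides = To2DDims(param.stride, "stride")` and `padding = To2DDims(param.pad, "pad")`.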
```diff
@@ -100,11 +111,21 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
   if (param.stride.ndim() == 2) {
     strides[0] = param.stride[0];
     strides[1] = param.stride[1];
+  } else if (param.stride.ndim() == 1) {
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[0];
+  } else {
+    LOG(FATAL) << "Unsupported stride dim";
   }
   mkldnn::memory::dims padding{0, 0};
   if (param.pad.ndim() == 2) {
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.pad.ndim() == 1) {
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[0];
+  } else {
+    LOG(FATAL) << "Unsupported pad dim";
   }
   mkldnn::memory::dims dilate{0, 0};
   if (param.dilate.ndim() == 2) {
```
```diff
@@ -127,11 +148,21 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
   if (param.stride.ndim() == 2) {
     strides[0] = param.stride[0];
     strides[1] = param.stride[1];
+  } else if (param.stride.ndim() == 1) {
+    strides[0] = param.stride[0];
+    strides[1] = param.stride[0];
+  } else {
+    LOG(FATAL) << "Unsupported stride dim";
   }
   mkldnn::memory::dims padding{0, 0};
   if (param.pad.ndim() == 2) {
     padding[0] = param.pad[0];
     padding[1] = param.pad[1];
+  } else if (param.pad.ndim() == 1) {
+    padding[0] = param.pad[0];
+    padding[1] = param.pad[0];
+  } else {
+    LOG(FATAL) << "Unsupported pad dim";
   }
   mkldnn::memory::dims dilate{0, 0};
   if (param.dilate.ndim() == 2) {
```
```diff
@@ -151,42 +182,93 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
   }
 }

-void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                                const std::vector<NDArray> &in_data,
-                                const std::vector<OpReqType> &req,
-                                const std::vector<NDArray> &out_data) {
-  TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
-  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+class MKLDNNDeconvForward {
+  std::shared_ptr<mkldnn::convolution_backward_data> fwd;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::memory> out;
+  OutDataOp data_op;
+
+ public:
+  MKLDNNDeconvForward(const DeconvolutionParam& param,
+                      const NDArray &data,
+                      const NDArray &weights,
+                      bool has_bias,
+                      const NDArray &output);
+  void SetDataHandle(const DeconvolutionParam& param,
+                     const OpContext &ctx,
+                     const std::vector<NDArray> &in_data,
+                     const std::vector<OpReqType> &req,
+                     const std::vector<NDArray> &out_data);
+
+  void Execute(const std::vector<NDArray> &out_data);
+
+ private:
+  mkldnn::convolution_backward_data::primitive_desc fwd_pd;
+};  // class MKLDNNDeconvForward

-  mkldnn::convolution_backward_data::primitive_desc deconvFwd_pd = GetDeconvFwd(
-      param, in_data[deconv::kData], in_data[deconv::kWeight], false,
-      out_data[deconv::kOut]);
+MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,
+                                         const NDArray &data,
+                                         const NDArray &weights,
+                                         bool has_bias,
+                                         const NDArray &output)
+    : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) {
+  this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+      fwd_pd.diff_dst_primitive_desc()));
+  this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+      fwd_pd.weights_primitive_desc()));
+  this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+      fwd_pd.diff_src_primitive_desc()));
+  this->fwd = std::shared_ptr<mkldnn::convolution_backward_data>(
+      new mkldnn::convolution_backward_data(fwd_pd,
+          mkldnn::primitive::at(*this->data),
+          mkldnn::primitive::at(*this->weight),
+          *this->out));
+}
+
+void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
+                                        const OpContext &ctx,
+                                        const std::vector<NDArray> &in_data,
+                                        const std::vector<OpReqType> &req,
+                                        const std::vector<NDArray> &out_data) {
   auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
-      deconvFwd_pd.diff_dst_primitive_desc());
+      fwd_pd.diff_dst_primitive_desc());
   const mkldnn::memory *weight_mem;
   if (ctx.is_train) {
     // TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
     // to the default format for now.
     if (in_data[deconv::kWeight].IsMKLDNN())
       const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder2Default();
     weight_mem = GetWeights(in_data[deconv::kWeight],
-                            deconvFwd_pd.weights_primitive_desc(),
+                            fwd_pd.weights_primitive_desc(),
                             param.num_group);
   } else {
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder(
-        deconvFwd_pd.weights_primitive_desc());
+        fwd_pd.weights_primitive_desc());
     weight_mem = in_data[deconv::kWeight].GetMKLDNNData();
   }
   auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
-                                 deconvFwd_pd.diff_src_primitive_desc(),
-                                 req[deconv::kOut]);
+                                 fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
+  auto output = out_mem.second;
+  this->data->set_data_handle(data_mem->get_data_handle());
+  this->weight->set_data_handle(weight_mem->get_data_handle());
+  this->out->set_data_handle(output->get_data_handle());
+  this->data_op = out_mem.first;
+}

-  MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(
-      deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second));
-  CommitOutput(out_data[deconv::kOut], out_mem);
+void MKLDNNDeconvForward::Execute(const std::vector<NDArray> &out_data) {
+  MKLDNNStream::Get()->RegisterPrim(*fwd);
+  CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get()));
   MKLDNNStream::Get()->Submit();
+}

+static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
+                                           const OpContext &ctx,
+                                           const std::vector<NDArray> &in_data,
+                                           const std::vector<NDArray> &out_data) {
   // add bias, broadcast bias to dim 1: channel
   if (!param.no_bias) {
     // MKLDNN only supports float right now.
```
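The hunk above splits the old monolithic forward function into three phases: build the primitive and placeholder memories once (the constructor), point them at the current call's buffers (SetDataHandle), then run (Execute). Below is a minimal, library-free sketch of that create-once/rebind-per-call pattern; all names are illustrative, not MXNet or MKLDNN APIs:

```cpp
#include <cstddef>
#include <iostream>

// Sketch of the create-once / rebind-per-call pattern used by
// MKLDNNDeconvForward. Every name here is hypothetical.
class CachedKernel {
 public:
  // Expensive setup happens once (analogous to building the MKLDNN
  // primitive and its memory placeholders in the constructor).
  explicit CachedKernel(std::size_t n) : n_(n) {}

  // Cheap per-call step: point the cached kernel at this call's buffers
  // (analogous to MKLDNNDeconvForward::SetDataHandle).
  void SetDataHandle(const float *in, float *out) { in_ = in; out_ = out; }

  // Run with whatever buffers are currently bound
  // (analogous to MKLDNNDeconvForward::Execute).
  void Execute() {
    for (std::size_t i = 0; i < n_; ++i) out_[i] = 2.f * in_[i];
  }

 private:
  std::size_t n_;
  const float *in_ = nullptr;
  float *out_ = nullptr;
};

int main() {
  float a[4] = {1, 2, 3, 4}, b[4];
  CachedKernel k(4);      // built once, then cached and reused
  k.SetDataHandle(a, b);  // rebound on every forward call
  k.Execute();
  std::cout << b[0] << '\n';  // prints 2
}
```

The payoff is that the expensive construction cost is paid once per distinct configuration, while each subsequent forward call only swaps raw data pointers.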
```diff
@@ -201,6 +283,55 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
   }
 }

+typedef MKLDNNParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;
+
+static inline MKLDNNDeconvForward &GetDeconvFwd(
+    const nnvm::NodeAttrs& attrs, const NDArray &data,
+    const NDArray &weights, const NDArray *bias,
+    const NDArray &output) {
+  static thread_local
+      std::unordered_map<MKLDNNDeconvSignature, MKLDNNDeconvForward, MKLDNNOpHash> fwds;
+  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+  MKLDNNDeconvSignature key(param);
+  // Here we can sign the conv op with NDArray because the conv primitive will
+  // decide the right layout for them, so we only need to get the shape and the
+  // data type of the arrays.
+  key.AddSign(data);
+  key.AddSign(weights);
+  key.AddSign(output);
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    bool has_bias = (bias != nullptr);
+    MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
+    auto ins_ret = fwds.insert(
+        std::pair<MKLDNNDeconvSignature, MKLDNNDeconvForward>(key, fwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
```
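GetDeconvFwd keys a thread_local cache by an op signature: the operator's parameters plus the shape and dtype of each array, so a primitive is rebuilt only when those change. A rough sketch of the idea, assuming nothing beyond the C++ standard library (this is not MXNet's actual MKLDNNParamOpSign/MKLDNNOpHash implementation):

```cpp
#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Illustrative signature key: a flat list of integers (op parameters, then
// each array's dtype and shape). Equal signatures mean a cached primitive
// can be reused.
struct OpSignature {
  std::vector<int64_t> eles;
  void AddSign(int64_t v) { eles.push_back(v); }
  bool operator==(const OpSignature &o) const { return eles == o.eles; }
};

struct OpSignatureHash {
  std::size_t operator()(const OpSignature &s) const {
    std::size_t h = 0;
    for (int64_t v : s.eles)  // standard hash-combine over the elements
      h ^= std::hash<int64_t>()(v) + 0x9e3779b9 + (h << 6) + (h >> 2);
    return h;
  }
};

// One cache per thread, mirroring the thread_local map in GetDeconvFwd.
template <typename Fwd>
Fwd &GetCached(const OpSignature &key, Fwd (*create)()) {
  static thread_local std::unordered_map<OpSignature, Fwd, OpSignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, create()).first;
  return it->second;
}
```

Because the map is thread_local, each worker thread builds its own primitives and no locking is needed, at the cost of some duplication across threads.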
```diff
+void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                                const std::vector<NDArray> &in_data,
+                                const std::vector<OpReqType> &req,
+                                const std::vector<NDArray> &out_data) {
+  TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
+  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
+
+  MKLDNNDeconvForward &deconvFwd = GetDeconvFwd(
+      attrs, in_data[deconv::kData], in_data[deconv::kWeight],
+      param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]);
+
+  deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data);
+
+  deconvFwd.Execute(out_data);
+
+  MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
+}
+
 void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                  const std::vector<NDArray>& inputs,
                                  const std::vector<OpReqType>& req,
```