
Rebase and Refactor deconv (#20)
* rebase to Da Zheng's refactor branch (Jan. 14); add a signature for MKLDNN Deconv and modify class MKLDNNDeconvForward

* fix make lint complaints
rongzha1 authored and zheng-da committed Jan 15, 2018
1 parent 2429c56 commit ea46c2f
Showing 2 changed files with 176 additions and 20 deletions.
25 changes: 25 additions & 0 deletions src/operator/nn/deconvolution-inl.h
@@ -163,6 +163,31 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
this->cudnn_off == other.cudnn_off &&
this->layout == other.layout;
}
#if MXNET_USE_MKLDNN == 1
static uint64_t ComputeHash(const TShape &shape) {
uint64_t hash = 0;
for (size_t i = 0; i < shape.ndim(); i++)
hash = hash * 2 + shape[i];
return hash;
}

uint64_t GetHash() const {
uint64_t hash = 0;
hash = hash * 2 + ComputeHash(kernel);
hash = hash * 2 + ComputeHash(stride);
hash = hash * 2 + ComputeHash(dilate);
hash = hash * 2 + ComputeHash(pad);
hash = hash * 2 + ComputeHash(adj);
hash = hash * 2 + ComputeHash(target_shape);
hash = hash * 2 + num_filter;
hash = hash * 2 + num_group;
hash = hash * 2 + workspace;
hash = hash * 2 + no_bias;
if (layout.has_value())
hash = hash * 2 + layout.value();
return hash;
}
#endif
};

} // namespace op
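
As a concrete check of the fold above (illustration only, not part of the diff): ComputeHash walks a shape left to right with hash = hash * 2 + dim, so a hypothetical 3x3 kernel folds to (0 * 2 + 3) * 2 + 3 = 9, and GetHash then chains the per-field results with the scalar parameters in a fixed order.

#include <cassert>
#include <cstdint>

int main() {
  // Same fold as ComputeHash, on a hypothetical 3x3 kernel shape.
  uint64_t h = 0;
  for (uint64_t d : {3, 3})
    h = h * 2 + d;   // (0*2 + 3)*2 + 3 == 9
  assert(h == 9);
  return 0;
}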
171 changes: 151 additions & 20 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -20,14 +20,15 @@
/*!
* \file mkldnn_deconvolution.cc
* \brief
* \author Da Zheng
* \author Da Zheng, Rong Zhang ([email protected])
*/

#if MXNET_USE_MKLDNN == 1

#include "../deconvolution-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

@@ -59,7 +60,7 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_(
}
}

static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
bool has_bias, const NDArray &output) {
auto data_md = GetMemDesc(data);
@@ -70,11 +71,21 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
@@ -100,11 +111,21 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
@@ -127,11 +148,21 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
@@ -151,42 +182,93 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
}
}
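
The 1-D-to-2-D normalization of stride and pad above is now repeated verbatim in all three primitive-descriptor helpers. A hypothetical follow-up refactor (not part of this commit) could factor it into a single helper:

// Hypothetical helper, not in this commit: expand a 1-D or 2-D tuple
// into the two spatial dims MKLDNN expects, failing on anything else.
static mkldnn::memory::dims To2D(const TShape &t, const char *name) {
  if (t.ndim() == 2)
    return {static_cast<int>(t[0]), static_cast<int>(t[1])};
  if (t.ndim() == 1)
    return {static_cast<int>(t[0]), static_cast<int>(t[0])};
  LOG(FATAL) << "Unsupported " << name << " dim";
  return {};
}
// Usage: mkldnn::memory::dims strides = To2D(param.stride, "stride");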

void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
class MKLDNNDeconvForward {
std::shared_ptr<mkldnn::convolution_backward_data> fwd;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> weight;
std::shared_ptr<mkldnn::memory> bias;
std::shared_ptr<mkldnn::memory> out;
OutDataOp data_op;

public:
MKLDNNDeconvForward(const DeconvolutionParam& param,
const NDArray &data,
const NDArray &weights,
bool has_bias,
const NDArray &output);
void SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

void Execute(const std::vector<NDArray> &out_data);

private:
mkldnn::convolution_backward_data::primitive_desc fwd_pd;
}; // class MKLDNNDeconvForward

mkldnn::convolution_backward_data::primitive_desc deconvFwd_pd = GetDeconvFwd(
param, in_data[deconv::kData], in_data[deconv::kWeight], false,
out_data[deconv::kOut]);
MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,
const NDArray &data,
const NDArray &weights,
bool has_bias,
const NDArray &output)
:fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) {
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
fwd_pd.diff_dst_primitive_desc()));
this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
fwd_pd.weights_primitive_desc()));
this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
fwd_pd.diff_src_primitive_desc()));
this->fwd = std::shared_ptr<mkldnn::convolution_backward_data>(
new mkldnn::convolution_backward_data(fwd_pd,
mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->weight),
*this->out));
}
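
Background for this constructor (not stated in the diff): no dedicated MKLDNN deconvolution primitive is used here, so the forward pass is expressed as the backward-data pass of an ordinary convolution. The deconvolution input binds to diff_dst and the deconvolution output to diff_src, which is why the primitive descriptor exposes diff_* names. For stride s, kernel k, pad p, and adjustment a, a transposed convolution maps a spatial size i to s*(i-1) + k - 2*p + a; a hypothetical one-axis check:

// Hypothetical shape check, not in this commit: transposed-conv
// output size along one spatial axis.
static int DeconvOutSize(int i, int k, int s, int p, int a) {
  return s * (i - 1) + k - 2 * p + a;
}
// e.g. DeconvOutSize(4, 3, 2, 1, 0) == 7: stride-2 deconv upsamples 4 -> 7.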

void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
deconvFwd_pd.diff_dst_primitive_desc());
fwd_pd.diff_dst_primitive_desc());
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
// TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
// to the default format for now.
if (in_data[deconv::kWeight].IsMKLDNN())
const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder2Default();
weight_mem = GetWeights(in_data[deconv::kWeight],
deconvFwd_pd.weights_primitive_desc(),
fwd_pd.weights_primitive_desc(),
param.num_group);
} else {
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder(
deconvFwd_pd.weights_primitive_desc());
fwd_pd.weights_primitive_desc());
weight_mem = in_data[deconv::kWeight].GetMKLDNNData();
}
auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
deconvFwd_pd.diff_src_primitive_desc(),
req[deconv::kOut]);
fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
auto output = out_mem.second;
this->data->set_data_handle(data_mem->get_data_handle());
this->weight->set_data_handle(weight_mem->get_data_handle());
this->out->set_data_handle(output->get_data_handle());
this->data_op = out_mem.first;
}

MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(
deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second));
CommitOutput(out_data[deconv::kOut], out_mem);
void MKLDNNDeconvForward::Execute(const std::vector<NDArray> &out_data) {
MKLDNNStream::Get()->RegisterPrim(*fwd);
CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get()));
MKLDNNStream::Get()->Submit();
}
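
The split between SetDataHandle and Execute is what makes caching worthwhile: the primitive is built once against placeholder memories, and each forward call only rebinds the underlying data pointers before resubmitting it. A minimal sketch of that pattern with hypothetical names (in the real operator, these objects live in the thread-local cache shown below):

#include <mkldnn.hpp>

// Hypothetical sketch, not this file's code: construct-once,
// rebind-per-call reuse of an MKLDNN primitive.
void SketchDeconvForward(
    const mkldnn::convolution_backward_data::primitive_desc &pd,
    const mkldnn::memory &weights, void *in_ptr, void *out_ptr) {
  // In the real operator these are created once per unique signature.
  mkldnn::memory src(pd.diff_dst_primitive_desc());
  mkldnn::memory dst(pd.diff_src_primitive_desc());
  mkldnn::convolution_backward_data prim(
      pd, mkldnn::primitive::at(src), mkldnn::primitive::at(weights), dst);
  // Per call: swap the data pointers, then submit the same primitive.
  src.set_data_handle(in_ptr);
  dst.set_data_handle(out_ptr);
  mkldnn::stream(mkldnn::stream::kind::eager).submit({prim}).wait();
}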

static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<NDArray> &out_data) {
// add bias, broadcast bias to dim 1: channel
if (!param.no_bias) {
// MKLDNN only supports float right now.
@@ -201,6 +283,55 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
}
}

typedef MKLDNNParamOpSign<DeconvolutionParam> MKLDNNDeconvSignature;

static inline MKLDNNDeconvForward &GetDeconvFwd(
const nnvm::NodeAttrs& attrs, const NDArray &data,
const NDArray &weights, const NDArray *bias,
const NDArray &output) {
static thread_local
std::unordered_map<MKLDNNDeconvSignature, MKLDNNDeconvForward, MKLDNNOpHash> fwds;
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
MKLDNNDeconvSignature key(param);
// Here we can sign the deconv op with NDArray because the deconv primitive
// will decide the right layout for them, so we only need to get the shape
// and the data type of the arrays.
key.AddSign(data);
key.AddSign(weights);
key.AddSign(output);
if (bias)
key.AddSign(*bias);

auto it = fwds.find(key);
if (it == fwds.end()) {
bool has_bias = (bias != nullptr);
MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
auto ins_ret = fwds.insert(
std::pair<MKLDNNDeconvSignature, MKLDNNDeconvForward>(key, fwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}
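
Design note on the cache above: fwds is thread_local, so each worker thread creates and reuses its own primitives without locking, and AddSign folds each array's shape and data type into the key, so a previously unseen input shape simply constructs (and caches) a fresh MKLDNNDeconvForward. The lookup-then-insert step is the usual memoization pattern; in isolation (hypothetical helper, not in this commit):

// Hypothetical helper, not in this commit: return the cached value for
// key, constructing it with make() on first use.
template <typename Map, typename Key, typename Make>
typename Map::mapped_type &GetOrCreate(Map &cache, const Key &key, Make make) {
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, make()).first;
  return it->second;
}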

void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);

MKLDNNDeconvForward &deconvFwd = GetDeconvFwd(
attrs, in_data[deconv::kData], in_data[deconv::kWeight],
param.no_bias ? nullptr : &in_data[deconv::kBias], out_data[deconv::kOut]);

deconvFwd.SetDataHandle(param, ctx, in_data, req, out_data);

deconvFwd.Execute(out_data);

MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
}

void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
