Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Rebase and Refactor deconv #20

Merged
merged 2 commits into from
Jan 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/operator/nn/deconvolution-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,31 @@ struct DeconvolutionParam : public dmlc::Parameter<DeconvolutionParam> {
this->cudnn_off == other.cudnn_off &&
this->layout == other.layout;
}
#if MXNET_USE_MKLDNN == 1
/*!
 * \brief Fold the dimensions of a shape into a single 64-bit value.
 * \param shape the shape whose dimensions are combined.
 * \return 0 when the shape has no dimensions; otherwise the result of applying
 *         acc = acc * 2 + shape[i] from the first dimension to the last.
 */
static uint64_t ComputeHash(const TShape &shape) {
  uint64_t acc = 0;
  size_t dim = 0;
  while (dim < shape.ndim()) {
    acc = acc * 2 + shape[dim];
    ++dim;
  }
  return acc;
}

uint64_t GetHash() const {
uint64_t hash = 0;
hash = hash * 2 + ComputeHash(kernel);
hash = hash * 2 + ComputeHash(stride);
hash = hash * 2 + ComputeHash(dilate);
hash = hash * 2 + ComputeHash(pad);
hash = hash * 2 + ComputeHash(adj);
hash = hash * 2 + ComputeHash(target_shape);
hash = hash * 2 + num_filter;
hash = hash * 2 + num_group;
hash = hash * 2 + workspace;
hash = hash * 2 + no_bias;
if (layout.has_value())
hash = hash * 2 + layout.value();
return hash;
}
#endif
};

} // namespace op
Expand Down
171 changes: 151 additions & 20 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@
/*!
* \file mkldnn_deconvolution.cc
* \brief
* \author Da Zheng
* \author Da Zheng, Rong Zhang ([email protected])
*/

#if MXNET_USE_MKLDNN == 1

#include "../deconvolution-inl.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

#if MXNET_USE_MKLDNN == 1
namespace mxnet {
namespace op {

Expand Down Expand Up @@ -59,7 +60,7 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwd_(
}
}

static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwdImpl(
const DeconvolutionParam& param, const NDArray &data, const NDArray &weights,
bool has_bias, const NDArray &output) {
auto data_md = GetMemDesc(data);
Expand All @@ -70,11 +71,21 @@ static mkldnn::convolution_backward_data::primitive_desc GetDeconvFwd(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
Expand All @@ -100,11 +111,21 @@ static mkldnn::convolution_forward::primitive_desc GetDeconvBwdData(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
Expand All @@ -127,11 +148,21 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
if (param.stride.ndim() == 2) {
strides[0] = param.stride[0];
strides[1] = param.stride[1];
} else if (param.stride.ndim() == 1) {
strides[0] = param.stride[0];
strides[1] = param.stride[0];
} else {
LOG(FATAL) << "Unsupported stride dim";
}
mkldnn::memory::dims padding{0, 0};
if (param.pad.ndim() == 2) {
padding[0] = param.pad[0];
padding[1] = param.pad[1];
} else if (param.pad.ndim() == 1) {
padding[0] = param.pad[0];
padding[1] = param.pad[0];
} else {
LOG(FATAL) << "Unsupported pad dim";
}
mkldnn::memory::dims dilate{0, 0};
if (param.dilate.ndim() == 2) {
Expand All @@ -151,42 +182,93 @@ static mkldnn::convolution_backward_weights::primitive_desc GetDeconvBwdWeights(
}
}

void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
// Bundles everything needed to run one MKLDNN deconvolution forward pass:
// the primitive descriptor, the primitive built from it, and the
// mkldnn::memory objects that primitive is bound to.
class MKLDNNDeconvForward {
// The forward primitive; it is an mkldnn::convolution_backward_data.
std::shared_ptr<mkldnn::convolution_backward_data> fwd;
// Memory objects created by the constructor and bound to `fwd`. Their data
// handles are replaced by SetDataHandle().
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> weight;
// Not assigned by any of the member functions visible in this part of the file.
std::shared_ptr<mkldnn::memory> bias;
std::shared_ptr<mkldnn::memory> out;
// OutDataOp returned by CreateMKLDNNMem in SetDataHandle(); Execute() passes
// it on to CommitOutput.
OutDataOp data_op;

public:
// Builds fwd_pd from the arguments via GetDeconvFwdImpl, then creates the
// memory objects and the primitive from that descriptor.
MKLDNNDeconvForward(const DeconvolutionParam& param,
const NDArray &data,
const NDArray &weights,
bool has_bias,
const NDArray &output);
// Points the memory objects at the buffers of the current input, weight and
// output arrays.
void SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

// Registers the primitive on the MKLDNN stream, commits the output and
// submits the stream.
void Execute(const std::vector<NDArray> &out_data);

private:
// Descriptor the memory objects and the primitive are created from.
mkldnn::convolution_backward_data::primitive_desc fwd_pd;
}; // class MKLDNNDeconvForward

mkldnn::convolution_backward_data::primitive_desc deconvFwd_pd = GetDeconvFwd(
param, in_data[deconv::kData], in_data[deconv::kWeight], false,
out_data[deconv::kOut]);
/*!
 * \brief Create the primitive descriptor, the memory objects and the forward
 *        primitive for one deconvolution configuration.
 * \param param deconvolution parameters, forwarded to GetDeconvFwdImpl.
 * \param data input array, forwarded to GetDeconvFwdImpl.
 * \param weights weight array, forwarded to GetDeconvFwdImpl.
 * \param has_bias bias flag, forwarded to GetDeconvFwdImpl.
 * \param output output array, forwarded to GetDeconvFwdImpl.
 */
MKLDNNDeconvForward::MKLDNNDeconvForward(const DeconvolutionParam& param,
                                         const NDArray &data,
                                         const NDArray &weights,
                                         bool has_bias,
                                         const NDArray &output)
    : fwd_pd(GetDeconvFwdImpl(param, data, weights, has_bias, output)) {
  // The forward pass runs as a convolution backward-data primitive, so the
  // input uses the diff_dst descriptor and the output the diff_src descriptor.
  this->data.reset(new mkldnn::memory(fwd_pd.diff_dst_primitive_desc()));
  this->weight.reset(new mkldnn::memory(fwd_pd.weights_primitive_desc()));
  this->out.reset(new mkldnn::memory(fwd_pd.diff_src_primitive_desc()));

  const mkldnn::primitive::at data_at(*this->data);
  const mkldnn::primitive::at weight_at(*this->weight);
  this->fwd.reset(
      new mkldnn::convolution_backward_data(fwd_pd, data_at, weight_at, *this->out));
}

/*!
 * \brief Point the cached memory objects at the buffers used by this call.
 * \param param deconvolution parameters; only num_group is read here.
 * \param ctx operator context; only is_train is read here.
 * \param in_data operator inputs; the kData and kWeight entries are used. The
 *        weight array is reordered in place through a const_cast.
 * \param req write requests; only req[deconv::kOut] is used.
 * \param out_data operator outputs; only out_data[deconv::kOut] is used.
 *
 * NOTE(review): each statement that mentions `deconvFwd_pd` sits directly next
 * to a counterpart that mentions `fwd_pd`. The `deconvFwd_pd` lines look like
 * the pre-refactor spelling of the same statement; confirm which one is live.
 */
void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data) {
// Input data, requested with the diff_dst descriptor of the primitive.
auto data_mem = in_data[deconv::kData].GetMKLDNNDataReorder(
deconvFwd_pd.diff_dst_primitive_desc());
fwd_pd.diff_dst_primitive_desc());
const mkldnn::memory *weight_mem;
if (ctx.is_train) {
// TODO(zhengda) kvstore doesn't handle MKLDNN correctly. Let's reorder it
// to the default format for now.
if (in_data[deconv::kWeight].IsMKLDNN())
const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder2Default();
weight_mem = GetWeights(in_data[deconv::kWeight],
deconvFwd_pd.weights_primitive_desc(),
fwd_pd.weights_primitive_desc(),
param.num_group);
} else {
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
const_cast<NDArray &>(in_data[deconv::kWeight]).Reorder(
deconvFwd_pd.weights_primitive_desc());
fwd_pd.weights_primitive_desc());
weight_mem = in_data[deconv::kWeight].GetMKLDNNData();
}
// Output memory, requested with the diff_src descriptor and the caller's
// write request.
auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
deconvFwd_pd.diff_src_primitive_desc(),
req[deconv::kOut]);
fwd_pd.diff_src_primitive_desc(), req[deconv::kOut]);
auto output = out_mem.second;
// Rebind the memory objects created in the constructor to this call's buffers.
this->data->set_data_handle(data_mem->get_data_handle());
this->weight->set_data_handle(weight_mem->get_data_handle());
this->out->set_data_handle(output->get_data_handle());
// Keep the OutDataOp so Execute() can pass it to CommitOutput.
this->data_op = out_mem.first;
}

MKLDNNStream::Get()->RegisterPrim(mkldnn::convolution_backward_data(
deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second));
CommitOutput(out_data[deconv::kOut], out_mem);
void MKLDNNDeconvForward::Execute(const std::vector<NDArray> &out_data) {
MKLDNNStream::Get()->RegisterPrim(*fwd);
CommitOutput(out_data[deconv::kOut], mkldnn_output_t(this->data_op, this->out.get()));
MKLDNNStream::Get()->Submit();
}

static void MKLDNNDeconvFwdBiasPostProcess(const DeconvolutionParam& param,
const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<NDArray> &out_data) {
// add bias, broadcast bias to dim 1: channel
if (!param.no_bias) {
// MKLDNN only supports float right now.
Expand All @@ -201,6 +283,55 @@ void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &c
}
}

// Signature type built from DeconvolutionParam; used as the cache key for
// forward objects.
using MKLDNNDeconvSignature = MKLDNNParamOpSign<DeconvolutionParam>;

/*!
 * \brief Return the cached forward object for the given parameters and arrays,
 *        creating and caching it on first use.
 * \param attrs node attributes holding the parsed DeconvolutionParam.
 * \param data input array.
 * \param weights weight array.
 * \param bias bias array, or nullptr when there is no bias.
 * \param output output array.
 * \return reference to the entry stored in the thread-local cache.
 */
static inline MKLDNNDeconvForward &GetDeconvFwd(
    const nnvm::NodeAttrs& attrs, const NDArray &data,
    const NDArray &weights, const NDArray *bias,
    const NDArray &output) {
  static thread_local
    std::unordered_map<MKLDNNDeconvSignature, MKLDNNDeconvForward, MKLDNNOpHash> cache;
  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);

  // The arrays themselves can be added to the signature: the primitive picks
  // the right layout on its own, so only the shape and the data type of each
  // array need to be part of the key.
  MKLDNNDeconvSignature key(param);
  key.AddSign(data);
  key.AddSign(weights);
  key.AddSign(output);
  if (bias != nullptr) {
    key.AddSign(*bias);
  }

  const auto cached = cache.find(key);
  if (cached != cache.end()) {
    return cached->second;
  }

  // Nothing cached for this signature yet: build the object and store it.
  const bool has_bias = (bias != nullptr);
  MKLDNNDeconvForward fwd(param, data, weights, has_bias, output);
  const auto ins_ret = cache.insert(
      std::pair<MKLDNNDeconvSignature, MKLDNNDeconvForward>(key, fwd));
  CHECK(ins_ret.second);
  return ins_ret.first->second;
}

/*!
 * \brief Run the deconvolution forward pass through MKLDNN.
 *
 * Looks up (or creates) the cached forward object, binds it to the arrays of
 * this call, executes it and then runs the bias post-processing step.
 *
 * \param attrs node attributes holding the parsed DeconvolutionParam.
 * \param ctx operator context; its kTempSpace resource initialises TmpMemMgr.
 * \param in_data operator inputs (kData, kWeight and, unless no_bias, kBias).
 * \param req write requests for the outputs.
 * \param out_data operator outputs.
 */
void MKLDNNDeconvolutionForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                                const std::vector<NDArray> &in_data,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &out_data) {
  TmpMemMgr::Get()->Init(ctx.requested[deconv::kTempSpace]);
  const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);

  // The bias entry of in_data is only touched when a bias is configured.
  const NDArray *bias_arr = param.no_bias ? nullptr : &in_data[deconv::kBias];
  MKLDNNDeconvForward &cached_fwd = GetDeconvFwd(attrs,
                                                 in_data[deconv::kData],
                                                 in_data[deconv::kWeight],
                                                 bias_arr,
                                                 out_data[deconv::kOut]);

  cached_fwd.SetDataHandle(param, ctx, in_data, req, out_data);
  cached_fwd.Execute(out_data);

  MKLDNNDeconvFwdBiasPostProcess(param, ctx, in_data, out_data);
}

void MKLDNNDeconvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
Expand Down