Handle kAddTo in MKLDNN operators.
zheng-da committed Nov 7, 2017
1 parent 13fcb9b commit beb8505
Showing 6 changed files with 149 additions and 126 deletions.
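
MXNet hands every operator one OpReqType per output that dictates how the result must be written: kWriteTo and kWriteInplace overwrite the destination, kNullOp skips it, and kAddTo requires accumulating into whatever the destination already holds. The MKLDNN operators previously handled only the overwrite cases with ad-hoc "copy back" code; this commit adds a pair of helpers (CreateMKLDNNMem and CommitOutput, in the first file below) so each operator can honor kAddTo as well. The following minimal sketch, which is not part of the commit, illustrates the contract on plain arrays:

// Sketch only (not from this commit): the write contract each OpReqType
// imposes on an operator producing n floats of `result` into `out`.
#include <cstddef>

enum OpReqType { kNullOp, kWriteTo, kWriteInplace, kAddTo };

void WriteOutput(const float *result, float *out, std::size_t n, OpReqType req) {
  if (req == kNullOp) return;              // no output requested
  for (std::size_t i = 0; i < n; ++i) {
    if (req == kAddTo)
      out[i] += result[i];                 // kAddTo: accumulate into existing data
    else
      out[i] = result[i];                  // kWriteTo / kWriteInplace: overwrite
  }
}
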
38 changes: 38 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -151,6 +151,44 @@ inline static mkldnn_mem_ptr CreateMKLDNNMem(const mkldnn::memory::primitive_desc
   return ret;
 }
 
+enum OutDataOp {
+  Noop,
+  CopyBack,
+  AddBack,
+};
+
+typedef std::pair<OutDataOp, mkldnn_mem_ptr> mkldnn_output_t;
+
+static inline mkldnn_output_t CreateMKLDNNMem(const NDArray &arr,
+    const mkldnn::memory::primitive_desc &desc, OpReqType req) {
+  if (kAddTo == req)
+    return mkldnn_output_t(OutDataOp::AddBack, CreateMKLDNNMem(desc));
+  else {
+    mkldnn_mem_ptr mem = const_cast<NDArray &>(arr).CreateMKLDNNData(desc);
+    if (mem == nullptr)
+      return mkldnn_output_t(OutDataOp::CopyBack, CreateMKLDNNMem(desc));
+    else
+      return mkldnn_output_t(OutDataOp::Noop, mem);
+  }
+}
+
+namespace op {
+void Sum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
+         const mkldnn::memory &out);
+}
+
+static inline void CommitOutput(const NDArray &arr, const mkldnn_output_t &res) {
+  if (res.first == CopyBack)
+    const_cast<NDArray &>(arr).CopyFrom(*res.second);
+  else if (res.first == AddBack) {
+    // TODO I might need to reorder.
+    mkldnn_mem_const_ptr mem = arr.GetMKLDNNData(res.second->get_primitive_desc());
+    mkldnn_mem_ptr out = CreateMKLDNNMem(res.second->get_primitive_desc());
+    op::Sum(*res.second, *mem, *out);
+    const_cast<NDArray &>(arr).CopyFrom(*out);
+  }
+}
+
 inline static mkldnn_mem_const_ptr GetWeights(const NDArray &arr,
     const mkldnn::memory::primitive_desc &target_pd, int num_groups) {
   mkldnn_mem_const_ptr mem;
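
Taken together, the helpers establish a three-step output protocol that the remaining files adopt. A hypothetical operator body (the name SomeOp_Forward and its parameters are illustrative, not from this commit) would look like:

// Usage sketch under assumed names: CreateMKLDNNMem picks the destination
// according to req, and CommitOutput performs any deferred copy or sum.
void SomeOp_Forward(const NDArray &out_arr, OpReqType req,
                    const mkldnn::convolution_forward::primitive_desc &fwd_pd,
                    const mkldnn::memory &data_mem,
                    const mkldnn::memory &weight_mem) {
  // 1. Choose where the primitive writes: the array's own buffer (Noop), or a
  //    temporary later copied back (CopyBack) or summed into the array (AddBack).
  mkldnn_output_t out_mem = CreateMKLDNNMem(out_arr, fwd_pd.dst_primitive_desc(), req);
  // 2. Register the primitive against the chosen memory.
  MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(
      fwd_pd, data_mem, weight_mem, *out_mem.second));
  // 3. Resolve the OutDataOp recorded in out_mem.first.
  CommitOutput(out_arr, out_mem);
  MKLDNNStream::Instance().Submit();
}
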
55 changes: 19 additions & 36 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -182,18 +182,18 @@ void MKLDNNConvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ct
   auto engine = CpuEngine::Instance().get_engine();
   auto weight_mem = GetWeights(in_data[conv::kWeight],
                                fwd_pd.weights_primitive_desc(), param.num_group);
-
-  auto out_mem = const_cast<NDArray &>(out_data[conv::kOut]).CreateMKLDNNData(
-      fwd_pd.dst_primitive_desc());
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut],
+                                 fwd_pd.dst_primitive_desc(), req[conv::kOut]);
 
   if (param.no_bias) {
     MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(fwd_pd,
-        *data_mem, *weight_mem, *out_mem));
+        *data_mem, *weight_mem, *out_mem.second));
   } else {
     auto bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd_pd.bias_primitive_desc());
     MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(fwd_pd,
-        *data_mem, *weight_mem, *bias_mem, *out_mem));
+        *data_mem, *weight_mem, *bias_mem, *out_mem.second));
   }
+  CommitOutput(out_data[conv::kOut], out_mem);
   MKLDNNStream::Instance().Submit();
 }

@@ -216,17 +216,11 @@ void MKLDNNConvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext &c
         bwdData_pd.diff_dst_primitive_desc());
     auto weight_mem = GetWeights(inputs[conv::kWeight + 1],
                                  bwdData_pd.weights_primitive_desc(), param.num_group);
-    auto in_grad_mem = const_cast<NDArray &>(in_grad[conv::kData]).CreateMKLDNNData(
-        bwdData_pd.diff_src_primitive_desc());
-    bool copy_back = false;
-    if (in_grad_mem == nullptr) {
-      in_grad_mem = CreateMKLDNNMem(bwdData_pd.diff_src_primitive_desc());
-      copy_back = true;
-    }
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[conv::kData],
+        bwdData_pd.diff_src_primitive_desc(), req[conv::kData]);
     MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data(bwdData_pd,
-        *out_grad_mem, *weight_mem, *in_grad_mem));
-    if (copy_back)
-      const_cast<NDArray &>(in_grad[conv::kData]).CopyFrom(*in_grad_mem);
+        *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    CommitOutput(in_grad[conv::kData], in_grad_mem);
   }
   if (req[conv::kWeight]) {
     mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
@@ -236,32 +230,21 @@ void MKLDNNConvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext &c
         bwdWeights_pd.diff_dst_primitive_desc());
     auto data_mem = inputs[conv::kData + 1].GetMKLDNNDataReorder(
         bwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = const_cast<NDArray &>(in_grad[conv::kWeight]).CreateMKLDNNData(
-        bwdWeights_pd.diff_weights_primitive_desc());
-    bool copy_back_weight = false;
-    bool copy_back_bias = false;
-    if (in_grad_weight == nullptr) {
-      in_grad_weight = CreateMKLDNNMem(bwdWeights_pd.diff_weights_primitive_desc());
-      copy_back_weight = true;
-    }
-    mkldnn_mem_const_ptr in_grad_bias;
+    auto in_grad_weight = CreateMKLDNNMem(in_grad[conv::kWeight],
+        bwdWeights_pd.diff_weights_primitive_desc(), req[conv::kWeight]);
+    mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
       MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight));
+          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
     } else {
-      in_grad_bias = const_cast<NDArray &>(in_grad[conv::kBias]).CreateMKLDNNData(
-          bwdWeights_pd.diff_bias_primitive_desc());
-      if (in_grad_bias == nullptr) {
-        in_grad_bias = CreateMKLDNNMem(bwdWeights_pd.diff_bias_primitive_desc());
-        copy_back_bias = true;
-      }
+      in_grad_bias = CreateMKLDNNMem(in_grad[conv::kBias],
+          bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
       MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight, *in_grad_bias));
+          bwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
+          *in_grad_bias.second));
     }
-    if (copy_back_weight)
-      const_cast<NDArray &>(in_grad[conv::kWeight]).CopyFrom(*in_grad_weight);
-    if (copy_back_bias)
-      const_cast<NDArray &>(in_grad[conv::kBias]).CopyFrom(*in_grad_bias);
+    CommitOutput(in_grad[conv::kWeight], in_grad_weight);
+    CommitOutput(in_grad[conv::kBias], in_grad_bias);
   }
   MKLDNNStream::Instance().Submit();
 }
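
A subtlety in the backward-weights block above: when param.no_bias is set, in_grad_bias is left default-constructed, and a value-initialized OutDataOp is 0, i.e. Noop, so the unconditional CommitOutput(in_grad[conv::kBias], in_grad_bias) at the end does nothing in that case.
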
61 changes: 18 additions & 43 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -179,18 +179,12 @@ void MKLDNNDeconvolution_Forward(const nnvm::NodeAttrs& attrs, const OpContext &
       deconvFwd_pd.diff_src_primitive_desc());
   auto weight_mem = GetWeights(in_data[deconv::kWeight],
                                deconvFwd_pd.weights_primitive_desc(), param.num_group);
-  auto out_mem = const_cast<NDArray &>(out_data[deconv::kOut]).CreateMKLDNNData(
-      deconvFwd_pd.diff_dst_primitive_desc());
-  bool copy_back = false;
-  if (out_mem == nullptr) {
-    out_mem = CreateMKLDNNMem(deconvFwd_pd.diff_dst_primitive_desc());
-    copy_back = true;
-  }
+  auto out_mem = CreateMKLDNNMem(out_data[deconv::kOut],
+      deconvFwd_pd.diff_dst_primitive_desc(), req[deconv::kOut]);
 
   MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_data(
-      deconvFwd_pd, *data_mem, *weight_mem, *out_mem));
-  if (copy_back)
-    const_cast<NDArray &>(out_data[deconv::kOut]).CopyFrom(*out_mem);
+      deconvFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+  CommitOutput(out_data[deconv::kOut], out_mem);
   MKLDNNStream::Instance().Submit();
   if (!param.no_bias) {
     // add bias, broadcast bias to dim 1: channel
@@ -209,7 +203,6 @@ void MKLDNNDeconvolution_Backward(const nnvm::NodeAttrs& attrs, const OpContext
                                   const std::vector<NDArray>& outputs) {
   const std::vector<NDArray> &in_grad = outputs;
   const DeconvolutionParam& param = nnvm::get<DeconvolutionParam>(attrs.parsed);
-
   CHECK_NE(req[deconv::kWeight], kWriteInplace) << "cannot write weight inplace";
   mkldnn::convolution_forward::primitive_desc bwdData_pd = GetDeconvBwdData(
       param, inputs[deconv::kData + 1], inputs[deconv::kWeight + 1], nullptr,
@@ -219,55 +212,37 @@
         bwdData_pd.src_primitive_desc());
     auto weight_mem = GetWeights(inputs[deconv::kWeight + 1],
                                  bwdData_pd.weights_primitive_desc(), param.num_group);
-    auto in_grad_mem = const_cast<NDArray &>(in_grad[deconv::kData]).CreateMKLDNNData(
-        bwdData_pd.dst_primitive_desc());
-    bool copy_back = false;
-    if (in_grad_mem == nullptr) {
-      in_grad_mem = CreateMKLDNNMem(bwdData_pd.dst_primitive_desc());
-      copy_back = true;
-    }
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[deconv::kData],
+        bwdData_pd.dst_primitive_desc(), req[deconv::kData]);
     MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_forward(bwdData_pd,
-        *out_grad_mem, *weight_mem, *in_grad_mem));
-    if (copy_back)
-      const_cast<NDArray &>(in_grad[deconv::kData]).CopyFrom(*in_grad_mem);
+        *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    CommitOutput(in_grad[deconv::kData], in_grad_mem);
   }
   if (req[deconv::kWeight]) {
     mkldnn::convolution_backward_weights::primitive_desc bwdWeights_pd
        = GetDeconvBwdWeights(param, inputs[deconv::kData + 1],
            inputs[deconv::kWeight + 1],
            param.no_bias ? nullptr : &inputs[deconv::kWeight + 1],
            inputs[deconv::kOut], bwdData_pd);
-    CHECK_NE(req[deconv::kWeight], kAddTo);
     auto out_grad_mem = inputs[deconv::kOut].GetMKLDNNDataReorder(
         bwdWeights_pd.diff_dst_primitive_desc());
     auto data_mem = inputs[deconv::kData + 1].GetMKLDNNDataReorder(
         bwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = const_cast<NDArray &>(in_grad[deconv::kWeight]).CreateMKLDNNData(
-        bwdWeights_pd.diff_weights_primitive_desc());
-    bool copy_back_weight = false;
-    bool copy_back_bias = false;
-    if (in_grad_weight == nullptr) {
-      in_grad_weight = CreateMKLDNNMem(bwdWeights_pd.diff_weights_primitive_desc());
-      copy_back_weight = true;
-    }
-    mkldnn_mem_const_ptr in_grad_bias;
+    auto in_grad_weight = CreateMKLDNNMem(in_grad[deconv::kWeight],
+        bwdWeights_pd.diff_weights_primitive_desc(), req[deconv::kWeight]);
+    mkldnn_output_t in_grad_bias;
     if (param.no_bias) {
       MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight));
+          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second));
     } else {
-      in_grad_bias = const_cast<NDArray &>(in_grad[deconv::kBias]).CreateMKLDNNData(
-          bwdWeights_pd.diff_bias_primitive_desc());
-      if (in_grad_bias == nullptr) {
-        in_grad_bias = CreateMKLDNNMem(bwdWeights_pd.diff_bias_primitive_desc());
-        copy_back_bias = true;
-      }
+      in_grad_bias = CreateMKLDNNMem(in_grad[deconv::kBias],
+          bwdWeights_pd.diff_bias_primitive_desc(), req[deconv::kBias]);
       MKLDNNStream::Instance().RegisterPrim(mkldnn::convolution_backward_weights(
-          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight, *in_grad_bias));
+          bwdWeights_pd, *out_grad_mem, *data_mem, *in_grad_weight.second,
+          *in_grad_bias.second));
     }
-    if (copy_back_weight)
-      const_cast<NDArray &>(in_grad[deconv::kWeight]).CopyFrom(*in_grad_weight);
-    if (copy_back_bias)
-      const_cast<NDArray &>(in_grad[deconv::kBias]).CopyFrom(*in_grad_bias);
+    CommitOutput(in_grad[deconv::kWeight], in_grad_weight);
+    CommitOutput(in_grad[deconv::kBias], in_grad_bias);
   }
   MKLDNNStream::Instance().Submit();
 }
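
Two things are worth noting in this file. Deconvolution is implemented with the convolution primitives swapped, which the diff leaves unchanged: the forward pass registers mkldnn::convolution_backward_data (hence deconvFwd_pd exposing diff_src/diff_dst primitive descriptors), while the data-gradient pass registers mkldnn::convolution_forward. And the CHECK_NE(req[deconv::kWeight], kAddTo) guard is deleted: the weight gradient can now honor kAddTo because CommitOutput performs the accumulation the check used to forbid.
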
61 changes: 19 additions & 42 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -93,23 +93,17 @@ void MKLDNNFC_Forward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
   auto data_mem = in_data[fullc::kData].GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
   auto weight_mem = in_data[fullc::kWeight].GetMKLDNNDataReorder(
       ipFwd_pd.weights_primitive_desc());
-  auto out_mem = const_cast<NDArray &>(out_data[fullc::kOut]).CreateMKLDNNData(
-      ipFwd_pd.dst_primitive_desc());
-  bool copy_back = false;
-  if (out_mem == nullptr) {
-    out_mem = CreateMKLDNNMem(ipFwd_pd.dst_primitive_desc());
-    copy_back = true;
-  }
+  auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
+      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
   if (param.no_bias) {
     MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_forward(
-        ipFwd_pd, *data_mem, *weight_mem, *out_mem));
+        ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
   } else {
     auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
     MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-        *data_mem, *weight_mem, *bias_mem, *out_mem));
+        *data_mem, *weight_mem, *bias_mem, *out_mem.second));
   }
-  if (copy_back)
-    const_cast<NDArray &>(out_data[fullc::kOut]).CopyFrom(*out_mem);
+  CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Instance().Submit();
 }

@@ -131,17 +125,11 @@ void MKLDNNFC_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
         ipBwdData_pd.diff_dst_primitive_desc());
     auto weight_mem = inputs[fullc::kWeight + 1].GetMKLDNNDataReorder(
         ipBwdData_pd.weights_primitive_desc());
-    auto in_grad_mem = const_cast<NDArray &>(in_grad[fullc::kData]).CreateMKLDNNData(
-        ipBwdData_pd.diff_src_primitive_desc());
-    bool copy_back = false;
-    if (in_grad_mem == nullptr) {
-      in_grad_mem = CreateMKLDNNMem(ipBwdData_pd.diff_src_primitive_desc());
-      copy_back = true;
-    }
+    auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
+        ipBwdData_pd.diff_src_primitive_desc(), req[fullc::kData]);
     MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_data(
-        ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem));
-    if (copy_back)
-      const_cast<NDArray &>(in_grad[fullc::kData]).CopyFrom(*in_grad_mem);
+        ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    CommitOutput(in_grad[fullc::kData], in_grad_mem);
   }
   if (req[fullc::kWeight]) {
     mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
Expand All @@ -152,32 +140,21 @@ void MKLDNNFC_Backward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
ipBwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = inputs[fullc::kData + 1].GetMKLDNNDataReorder(
ipBwdWeights_pd.src_primitive_desc());
auto in_grad_weight = const_cast<NDArray &>(in_grad[fullc::kWeight]).CreateMKLDNNData(
ipBwdWeights_pd.diff_weights_primitive_desc());
bool copy_back_weight = false;
bool copy_back_bias = false;
if (in_grad_weight == nullptr) {
in_grad_weight = CreateMKLDNNMem(ipBwdWeights_pd.diff_weights_primitive_desc());
copy_back_weight = true;
}
mkldnn_mem_const_ptr in_grad_bias;
auto in_grad_weight = CreateMKLDNNMem(in_grad[fullc::kWeight],
ipBwdWeights_pd.diff_weights_primitive_desc(), req[fullc::kWeight]);
mkldnn_output_t in_grad_bias;
if (param.no_bias) {
MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_weights(
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight));
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
} else {
in_grad_bias = const_cast<NDArray &>(in_grad[fullc::kBias]).CreateMKLDNNData(
ipBwdWeights_pd.diff_bias_primitive_desc());
if (in_grad_bias == nullptr) {
in_grad_bias = CreateMKLDNNMem(ipBwdWeights_pd.diff_bias_primitive_desc());
copy_back_bias = true;
}
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
ipBwdWeights_pd.diff_bias_primitive_desc(), req[fullc::kBias]);
MKLDNNStream::Instance().RegisterPrim(mkldnn::inner_product_backward_weights(
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight, *in_grad_bias));
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
*in_grad_bias.second));
}
if (copy_back_weight)
const_cast<NDArray &>(in_grad[fullc::kWeight]).CopyFrom(*in_grad_weight);
if (copy_back_bias)
const_cast<NDArray &>(in_grad[fullc::kBias]).CopyFrom(*in_grad_bias);
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
}
MKLDNNStream::Instance().Submit();
}
8 changes: 3 additions & 5 deletions src/operator/nn/mkldnn/mkldnn_relu-inl.h
@@ -76,9 +76,7 @@ void MKLDNNRelu_Backward(const OpContext &ctx, const NDArray &out_grad,
     return;
   }
 
-  // TODO we need to handle req
   std::shared_ptr<const mkldnn::memory> diff_dst_memory = out_grad.GetMKLDNNData();
-  // TODO shouldn't it be out_data?
   std::shared_ptr<const mkldnn::memory> input_mem = in_data.GetMKLDNNData();
   mkldnn::memory::primitive_desc data_mpd = input_mem->get_primitive_desc();
   mkldnn::memory::desc data_md = data_mpd.desc();
@@ -92,11 +90,11 @@
   mkldnn::eltwise_backward::desc bw_desc(mkldnn::eltwise_relu, diff_md, data_md, alpha);
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine, fw_pdesc);
 
-  std::shared_ptr<const mkldnn::memory> diff_src_memory
-      = const_cast<NDArray &>(in_grad).CreateMKLDNNData(bw_pdesc.diff_src_primitive_desc());
+  auto diff_src_memory = CreateMKLDNNMem(in_grad, bw_pdesc.diff_src_primitive_desc(), req);
   MKLDNNStream &stream = MKLDNNStream::Instance();
   stream.RegisterPrim(mkldnn::eltwise_backward(bw_pdesc, *input_mem,
-      *diff_dst_memory, *diff_src_memory));
+      *diff_dst_memory, *diff_src_memory.second));
+  CommitOutput(in_grad, diff_src_memory);
   stream.Submit();
 }

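
The deleted "TODO we need to handle req" is precisely what this commit resolves: activation backward receives a single OpReqType rather than a per-output vector, so req is forwarded directly to CreateMKLDNNMem, and the kAddTo case is now handled by CommitOutput.
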
