This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add primitive cache for MKL-DNN sum (elemwise_add operator) #14914

Merged
merged 2 commits on May 15, 2019
105 changes: 87 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_sum.cc
@@ -24,6 +24,7 @@
 */
#include <iostream>

+#include "../../operator_common.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

@@ -58,37 +59,105 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
  MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
}

+class MKLDNNSumFwd {
+ public:
+  mkldnn::sum::primitive_desc fwd_pd;
+
+  MKLDNNSumFwd(const std::vector<float> &scales,
+               const std::vector<mkldnn::memory::primitive_desc> &data_md)
+      : fwd_pd(scales, data_md) {
+    data_.resize(data_md.size());
+  }
+
+  void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);
+
+  const mkldnn::sum &GetFwd() const { return *fwd_; }
+
+ private:
+  std::shared_ptr<mkldnn::sum> fwd_;
+  std::vector<std::shared_ptr<mkldnn::memory>> data_;
+  std::vector<mkldnn::primitive::at> data_mem_;
+  std::shared_ptr<mkldnn::memory> out_;
+};
+
+static MKLDNNSumFwd &GetSumForward(
+    const std::vector<float> &scales, const std::vector<NDArray> &in_data,
+    const std::vector<mkldnn::memory::primitive_desc> &data_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(in_data);
+
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    MKLDNNSumFwd fwd(scales, data_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+void MKLDNNSumFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
+                             const mkldnn::memory &output) {
+  auto num_inputs = data_.size();
+  CHECK_EQ(in_data.size(), num_inputs);
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
+    if (this->data_[i] == nullptr) {
+      this->data_[i] = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
+      this->data_mem_.push_back(*this->data_[i]);
+    } else {
+      this->data_[i]->set_data_handle(in_data[i]->get_data_handle());
+    }
+  }
+  if (this->out_ == nullptr)
+    this->out_ = std::shared_ptr<mkldnn::memory>(
+        new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out_->set_data_handle(output.get_data_handle());
+
+  if (this->fwd_ == nullptr)
+    this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_));
+}
+
void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &inputs, const OpReqType &req,
                      const NDArray &out_data) {
  if (req == kNullOp) {
    return;
  }

  TmpMemMgr::Get()->Init(ctx.requested[0]);
-  std::vector<mkldnn::primitive::at> in_prims;
-  std::vector<mkldnn::memory::primitive_desc> in_pds(inputs.size());
-  std::vector<float> scales(inputs.size(), 1);
-  in_prims.reserve(inputs.size());
-  std::vector<NDArray> in_bufs(inputs.size());
-  for (size_t i = 0; i < inputs.size(); i++) {
+  auto num_inputs = inputs.size();
+  std::vector<mkldnn::memory::primitive_desc> data_md;
+  std::vector<const mkldnn::memory *> data_mem;
+  std::vector<float> scales(num_inputs, 1);
+  std::vector<NDArray> in_bufs(num_inputs);
+
+  data_md.reserve(num_inputs);
+  data_mem.reserve(num_inputs);
+
+  for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
    const mkldnn::memory *in_mem;
    if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
      in_bufs[i] = inputs[i].Reorder2Default();
      in_mem = in_bufs[i].GetMKLDNNData();
    } else {
      in_bufs[i] = inputs[i];
      in_mem = inputs[i].GetMKLDNNData();
    }
-    in_prims.push_back(*in_mem);
-    in_pds[i] = in_mem->get_primitive_desc();
+    mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc();
+    data_md.push_back(tmp_pd);
+    data_mem.push_back(in_mem);
  }

-  mkldnn::sum::primitive_desc pdesc(scales, in_pds);
-  auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req, &inputs[0]);
-  MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
-  CommitOutput(out_data, mem);
-  stream->Submit();
+  MKLDNNSumFwd &fwd = GetSumForward(scales, in_bufs, data_md);
+  mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data,
+                                                   fwd.fwd_pd.dst_primitive_desc(),
+                                                   req,
+                                                   &in_bufs[0]);
+  fwd.SetNewMem(data_mem, *out_mem.second);
+  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+  CommitOutput(out_data, out_mem);
+  MKLDNNStream::Get()->Submit();
}

} // namespace op
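The heart of the patch is the thread-local primitive cache above: GetSumForward keys a cached MKLDNNSumFwd on the input signature, so the relatively expensive mkldnn::sum construction happens once per distinct input configuration, and later calls only rebind data handles. Below is a minimal, self-contained sketch of the same lookup-or-build pattern; CachedPrim, Key, and KeyHash are hypothetical stand-ins for MXNet's OpSignature and OpHash, not the actual API.

#include <cstdint>
#include <functional>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for an expensive-to-build MKL-DNN primitive.
struct CachedPrim {
  explicit CachedPrim(size_t n) : num_inputs(n) {}
  size_t num_inputs;
  void *data_handle = nullptr;  // rebound cheaply on every forward call
};

// Hypothetical key: a flattened shape/dtype signature of the inputs.
using Key = std::vector<int64_t>;

struct KeyHash {
  size_t operator()(const Key &k) const {
    size_t h = 0;
    for (int64_t v : k)  // simple hash-combine over the signature
      h = h * 131 + std::hash<int64_t>()(v);
    return h;
  }
};

CachedPrim &GetCachedPrim(const Key &key, size_t num_inputs) {
  // One cache per thread, so no locking is needed; this mirrors the
  // DMLC_CXX11_THREAD_LOCAL / MX_THREAD_LOCAL switch in the diff.
  static thread_local std::unordered_map<Key, CachedPrim, KeyHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())  // build once per distinct signature
    it = cache.emplace(key, CachedPrim(num_inputs)).first;
  return it->second;
}

On a cache hit, the only per-call work left is the equivalent of set_data_handle, which is why the patch separates construction (the MKLDNNSumFwd constructor) from binding (SetNewMem).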
8 changes: 7 additions & 1 deletion src/operator/tensor/elemwise_binary_op_basic.cc
@@ -30,6 +30,12 @@
namespace mxnet {
namespace op {

+bool SupportMKLDNNSum(const NDArray& input) {
+  int ndim = input.shape().ndim();
+  return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) &&
+         input.storage_type() == kDefaultStorage;
+}
+
Contributor:

Why not put this function into mkldnn_base-inl.h?

Contributor (Author):

Currently, some MKLDNN ops support ndim in [1, 4], while some ops still don't support ndim = 3, so for now there will be several such support-check functions for different ops. Once all the ops are finalized, the similar functions can be combined.
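For illustration, the combined check the author mentions might take the supported ndim range as parameters. This SupportMKLDNNOp helper is a hypothetical sketch, not part of the PR, and assumes the usual MXNet headers for NDArray, mshadow::kFloat32, and kDefaultStorage:

// Hypothetical merged support check; each op passes the ndim range it
// actually supports (an op that rejects one specific ndim, such as 3,
// would need a predicate rather than a simple min/max range).
bool SupportMKLDNNOp(const NDArray &input, int min_ndim, int max_ndim) {
  const int ndim = input.shape().ndim();
  return input.dtype() == mshadow::kFloat32 &&
         ndim >= min_ndim && ndim <= max_ndim &&
         input.storage_type() == kDefaultStorage;
}

With that in place, SupportMKLDNNSum(input) would reduce to SupportMKLDNNOp(input, 1, 4).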


static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
                          const OpContext& ctx,
                          const std::vector<NDArray>& inputs,
@@ -38,7 +44,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_MKLDNN == 1
-  if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) {
+  if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) {
    MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]);
    return;
  } else if (inputs[0].storage_type() == kDefaultStorage
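A closing note on why reusing a cached primitive is safe at all: an MKL-DNN (v0.x) primitive holds onto the mkldnn::memory objects it was constructed against, so MKLDNNSumFwd::SetNewMem keeps those memory objects alive across calls and only swaps the underlying buffer pointers with set_data_handle. The toy program below illustrates that aliasing contract; Memory and Primitive are hypothetical plain-C++ stand-ins, not MKL-DNN types.

#include <cassert>

// Toy "memory" whose buffer pointer can be rebound, like
// mkldnn::memory::set_data_handle.
struct Memory {
  void *handle = nullptr;
  void set_data_handle(void *h) { handle = h; }
};

// Toy "primitive" that remembers the Memory it was built against,
// the way an MKL-DNN primitive captures its input/output memories.
struct Primitive {
  explicit Primitive(Memory *m) : mem(m) {}
  Memory *mem;
  void *current_input() const { return mem->handle; }
};

int main() {
  float a = 1.f, b = 2.f;
  Memory mem;
  mem.set_data_handle(&a);
  Primitive prim(&mem);                // built once, then cached
  assert(prim.current_input() == &a);

  mem.set_data_handle(&b);             // cheap rebind on the next call
  assert(prim.current_input() == &b);  // the cached primitive sees new data
  return 0;
}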