Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Add primitive cache for mkldnn sum
Browse files Browse the repository at this point in the history
  • Loading branch information
ciyongch committed May 5, 2019
1 parent 5ba285b commit ddb0c09
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 19 deletions.
104 changes: 86 additions & 18 deletions src/operator/nn/mkldnn/mkldnn_sum.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
*/
#include <iostream>

#include "../../operator_common.h"
#include "./mkldnn_ops-inl.h"
#include "./mkldnn_base-inl.h"

Expand Down Expand Up @@ -58,37 +59,104 @@ void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
MKLDNNStream::Get()->RegisterPrim(mkldnn::sum(sum_pd, inputs, out));
}

class MKLDNNSumFwd {
public:
mkldnn::sum::primitive_desc fwd_pd;

MKLDNNSumFwd(const std::vector<float> &scales,
const std::vector<mkldnn::memory::primitive_desc> &data_md)
: fwd_pd(scales, data_md) {
data_.resize(data_md.size());
}

void SetNewMem(const std::vector<const mkldnn::memory *> &in_data, const mkldnn::memory &output);

const mkldnn::sum &GetFwd() const { return *fwd_; }

private:
std::shared_ptr<mkldnn::sum> fwd_;
std::vector<std::shared_ptr<mkldnn::memory>> data_;
std::vector<mkldnn::primitive::at> data_mem_;
std::shared_ptr<mkldnn::memory> out_;
};

static MKLDNNSumFwd &GetSumForward(
const std::vector<float> &scales, const std::vector<NDArray> &in_data,
const std::vector<mkldnn::memory::primitive_desc> &data_md) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
#else
static MX_THREAD_LOCAL std::unordered_map<OpSignature, MKLDNNSumFwd, OpHash> fwds;
#endif
OpSignature key;
key.AddSign(in_data);

auto it = fwds.find(key);
if (it == fwds.end()) {
MKLDNNSumFwd fwd(scales, data_md);
it = AddToCache(&fwds, key, fwd);
}
return it->second;
}

void MKLDNNSumFwd::SetNewMem(const std::vector<const mkldnn::memory *> &in_data,
const mkldnn::memory &output) {
auto num_inputs = data_.size();
CHECK_EQ(in_data.size(), num_inputs);
for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
if (this->data_[i] == nullptr) {
this->data_[i] = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(in_data[i]->get_primitive_desc(), in_data[i]->get_data_handle()));
this->data_mem_.push_back(*this->data_[i]);
} else {
this->data_[i]->set_data_handle(in_data[i]->get_data_handle());
}
}
if (this->out_ == nullptr)
this->out_ = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(fwd_pd.dst_primitive_desc(), output.get_data_handle()));
else
this->out_->set_data_handle(output.get_data_handle());

if (this->fwd_ == nullptr)
this->fwd_.reset(new mkldnn::sum(fwd_pd, this->data_mem_, *this->out_));
}

void MKLDNNSumForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs, const OpReqType &req,
const NDArray &out_data) {
if (req == kNullOp) {
return;
}

TmpMemMgr::Get()->Init(ctx.requested[0]);
std::vector<mkldnn::primitive::at> in_prims;
std::vector<mkldnn::memory::primitive_desc> in_pds(inputs.size());
std::vector<float> scales(inputs.size(), 1);
in_prims.reserve(inputs.size());
std::vector<NDArray> in_bufs(inputs.size());
for (size_t i = 0; i < inputs.size(); i++) {
auto num_inputs = inputs.size();
std::vector<mkldnn::memory::primitive_desc> data_md;
std::vector<const mkldnn::memory *> data_mem;
std::vector<float> scales(num_inputs, 1);
std::vector<NDArray> in_bufs(num_inputs);

data_md.reserve(num_inputs);
data_mem.reserve(num_inputs);

for (index_t i = 0; i < static_cast<index_t>(num_inputs); ++i) {
const mkldnn::memory *in_mem;
if (inputs[i].IsMKLDNNData() && inputs[i].IsView()) {
in_bufs[i] = inputs[i].Reorder2Default();
in_mem = in_bufs[i].GetMKLDNNData();
} else {
in_mem = inputs[i].GetMKLDNNData();
}
in_prims.push_back(*in_mem);
in_pds[i] = in_mem->get_primitive_desc();
mkldnn::memory::primitive_desc tmp_pd = in_mem->get_primitive_desc();
data_md.push_back(tmp_pd);
data_mem.push_back(in_mem);
}

mkldnn::sum::primitive_desc pdesc(scales, in_pds);
auto mem = CreateMKLDNNMem(out_data, pdesc.dst_primitive_desc(), req, &inputs[0]);
MKLDNNStream *stream = MKLDNNStream::Get();
stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
CommitOutput(out_data, mem);
stream->Submit();
MKLDNNSumFwd &fwd = GetSumForward(scales, inputs, data_md);
mxnet::mkldnn_output_t out_mem = CreateMKLDNNMem(out_data,
fwd.fwd_pd.dst_primitive_desc(),
req,
&inputs[0]);
fwd.SetNewMem(data_mem, *out_mem.second);
MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
CommitOutput(out_data, out_mem);
MKLDNNStream::Get()->Submit();
}

} // namespace op
Expand Down
8 changes: 7 additions & 1 deletion src/operator/tensor/elemwise_binary_op_basic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
namespace mxnet {
namespace op {

bool SupportMKLDNNSum(const NDArray& input) {
int ndim = input.shape().ndim();
return input.dtype() == mshadow::kFloat32 && (ndim >= 1 && ndim <= 4) &&
input.storage_type() == kDefaultStorage;
}

static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<NDArray>& inputs,
Expand All @@ -38,7 +44,7 @@ static void ElemwiseAddEx(const nnvm::NodeAttrs& attrs,
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
#if MXNET_USE_MKLDNN == 1
if (SupportMKLDNN(inputs[0]) && SupportMKLDNN(inputs[1])) {
if (SupportMKLDNNSum(inputs[0]) && SupportMKLDNNSum(inputs[1])) {
MKLDNNSumForward(attrs, ctx, inputs, req[0], outputs[0]);
return;
} else if (inputs[0].storage_type() == kDefaultStorage
Expand Down

0 comments on commit ddb0c09

Please sign in to comment.