
MKLDNN Forward FullyConnected op cache #11611

Merged · 5 commits · Aug 26, 2018
17 changes: 17 additions & 0 deletions src/operator/nn/fully_connected-inl.h
@@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
    DMLC_DECLARE_FIELD(flatten).set_default(true)
    .describe("Whether to collapse all but the first axis of the input data tensor.");
  }
  bool operator==(const FullyConnectedParam& other) const {
    return this->num_hidden == other.num_hidden &&
           this->no_bias == other.no_bias &&
           this->flatten == other.flatten;
  }
};

template<typename xpu, typename DType>
@@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,

} // namespace op
} // namespace mxnet
namespace std {
template<>
struct hash<mxnet::op::FullyConnectedParam> {
  size_t operator()(const mxnet::op::FullyConnectedParam& val) const {
    size_t ret = 0;
    ret = dmlc::HashCombine(ret, val.num_hidden);
    ret = dmlc::HashCombine(ret, val.no_bias);
    ret = dmlc::HashCombine(ret, val.flatten);
    return ret;
  }
};
}  // namespace std
#endif // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
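
Taken together, operator== and the std::hash specialization are what allow FullyConnectedParam to key an std::unordered_map, which the primitive cache in the next file relies on (via ParamOpSign and OpHash). A minimal, self-contained sketch of the idiom: FCParam and the hand-rolled HashCombine below are illustrative stand-ins for FullyConnectedParam and dmlc::HashCombine, not MXNet code.

```cpp
// Standalone sketch (not MXNet code): why a param type needs both
// operator== and a std::hash specialization to key an unordered_map.
#include <cstddef>
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct FCParam {
  int num_hidden;
  bool no_bias;
  bool flatten;
  // Equality resolves hash-bucket collisions.
  bool operator==(const FCParam& other) const {
    return num_hidden == other.num_hidden &&
           no_bias == other.no_bias &&
           flatten == other.flatten;
  }
};

// Illustrative stand-in for dmlc::HashCombine (boost-style combiner).
inline size_t HashCombine(size_t seed, size_t v) {
  return seed ^ (std::hash<size_t>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

namespace std {
template <>
struct hash<FCParam> {
  size_t operator()(const FCParam& p) const {
    size_t ret = 0;
    ret = HashCombine(ret, static_cast<size_t>(p.num_hidden));
    ret = HashCombine(ret, static_cast<size_t>(p.no_bias));
    ret = HashCombine(ret, static_cast<size_t>(p.flatten));
    return ret;
  }
};
}  // namespace std

int main() {
  std::unordered_map<FCParam, std::string> cache;
  cache[{1024, false, true}] = "cached primitive A";
  // Same field values -> same hash and equality -> same cache entry.
  std::cout << cache[{1024, false, true}] << "\n";
}
```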
118 changes: 106 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
}
}

class MKLDNNFullyConnectForward {
  std::shared_ptr<mkldnn::memory> data;
  std::shared_ptr<mkldnn::memory> weight;
  std::shared_ptr<mkldnn::memory> out;
  std::shared_ptr<mkldnn::memory> bias;
  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;

 public:
  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;

  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
                            const NDArray &data, const NDArray &weight,
                            const NDArray *bias,
                            const mkldnn::memory::desc &output)
      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}

  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
                 const mkldnn::memory *bias, const mkldnn::memory &output) {
    if (this->data == nullptr)
      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
          ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
    else
      this->data->set_data_handle(data.get_data_handle());

    if (this->weight == nullptr)
      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
          ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
    else
      this->weight->set_data_handle(weight.get_data_handle());

    if (this->out == nullptr)
      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
          ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
    else
      this->out->set_data_handle(output.get_data_handle());

    if (bias != nullptr) {
      if (this->bias == nullptr)
        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
            ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
      else
        this->bias->set_data_handle(bias->get_data_handle());
      if (this->ipFwd == nullptr)
        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
            new mkldnn::inner_product_forward(
                ipFwd_pd, mkldnn::primitive::at(*this->data),
                mkldnn::primitive::at(*this->weight),
                mkldnn::primitive::at(*this->bias), *this->out));
    } else if (this->ipFwd == nullptr) {
      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
          new mkldnn::inner_product_forward(
              ipFwd_pd, mkldnn::primitive::at(*this->data),
              mkldnn::primitive::at(*this->weight), *this->out));
    }
  }

  const mkldnn::inner_product_forward &GetIpFwd() const {
    return *ipFwd;
  }
};
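
The point of SetNewMem is that the mkldnn::memory wrappers and the inner_product_forward primitive are constructed once, and every later call merely rebinds the raw data handles with set_data_handle. A library-free sketch of that pattern follows; Memory, Primitive, and CachedForward are illustrative stand-ins, not the MKLDNN API.

```cpp
// Standalone sketch (not MXNet code): build the expensive primitive once,
// then only swap the data pointers it reads/writes on each call.
#include <iostream>
#include <memory>

// Stands in for mkldnn::memory: a descriptor bound to a raw buffer.
class Memory {
 public:
  explicit Memory(void* handle) : handle_(handle) {}
  void set_data_handle(void* handle) { handle_ = handle; }  // cheap rebind
  void* get_data_handle() const { return handle_; }
 private:
  void* handle_;
};

// Stands in for mkldnn::inner_product_forward: expensive to construct,
// cheap to re-execute against whatever buffers are currently bound.
class Primitive {
 public:
  Primitive(std::shared_ptr<Memory> src, std::shared_ptr<Memory> dst)
      : src_(std::move(src)), dst_(std::move(dst)) {
    std::cout << "expensive primitive construction\n";
  }
  void execute() const {
    *static_cast<float*>(dst_->get_data_handle()) =
        *static_cast<float*>(src_->get_data_handle()) * 2.f;
  }
 private:
  std::shared_ptr<Memory> src_, dst_;
};

class CachedForward {
 public:
  void SetNewMem(void* src, void* dst) {
    if (!src_) src_ = std::make_shared<Memory>(src);
    else src_->set_data_handle(src);
    if (!dst_) dst_ = std::make_shared<Memory>(dst);
    else dst_->set_data_handle(dst);
    if (!fwd_) fwd_ = std::make_shared<Primitive>(src_, dst_);  // built once
  }
  const Primitive& GetFwd() const { return *fwd_; }
 private:
  std::shared_ptr<Memory> src_, dst_;
  std::shared_ptr<Primitive> fwd_;
};

int main() {
  CachedForward fc;
  float in1 = 3.f, out1 = 0.f, in2 = 5.f, out2 = 0.f;
  fc.SetNewMem(&in1, &out1);   // constructs the primitive
  fc.GetFwd().execute();
  fc.SetNewMem(&in2, &out2);   // only rebinds handles, no reconstruction
  fc.GetFwd().execute();
  std::cout << out1 << " " << out2 << "\n";  // prints: 6 10
}
```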

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

static inline MKLDNNFullyConnectForward &GetFCFwd(
    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
    const NDArray *bias, const mkldnn::memory::desc &output,
    const bool is_train) {
#if DMLC_CXX11_THREAD_LOCAL
  static thread_local std::unordered_map<MKLDNNFullyconSignature,
                                         MKLDNNFullyConnectForward, OpHash> fcFwds;
#else
  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
                                            MKLDNNFullyConnectForward, OpHash> fcFwds;
#endif
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  MKLDNNFullyconSignature key(param);
  key.AddSign(data);
  key.AddSign(weight);
  key.AddSign(is_train);

  if (bias)
    key.AddSign(*bias);

  auto it = fcFwds.find(key);
  if (it == fcFwds.end()) {
    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
                                    output);
    auto ins_ret = fcFwds.insert(
        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
    CHECK(ins_ret.second);
    it = ins_ret.first;
  }
  return it->second;
}
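
GetFCFwd memoizes one fully-initialized forward object per operator signature in a thread-local map: each worker thread keeps its own cache, so no locking is required, and the #if simply falls back to MX_THREAD_LOCAL where C++11 thread_local is unavailable. A minimal sketch of the same memoization idiom, with illustrative names (Expensive, GetCached) and a string key standing in for the signature:

```cpp
// Standalone sketch (not MXNet code) of the thread-local memoization idiom.
#include <iostream>
#include <string>
#include <unordered_map>

struct Expensive {
  explicit Expensive(const std::string& key) {
    std::cout << "building primitive for " << key << "\n";
  }
};

Expensive& GetCached(const std::string& key) {
  // One cache per thread: no mutex needed; entries live until thread exit.
  static thread_local std::unordered_map<std::string, Expensive> cache;
  auto it = cache.find(key);
  if (it == cache.end()) {
    auto ins = cache.emplace(key, Expensive(key));
    it = ins.first;
  }
  return it->second;
}

int main() {
  GetCached("fc:num_hidden=1024,no_bias=0,is_train=0");  // builds
  GetCached("fc:num_hidden=1024,no_bias=0,is_train=0");  // cache hit, silent
  GetCached("fc:num_hidden=512,no_bias=1,is_train=0");   // new signature, builds
}
```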

void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                     const std::vector<NDArray> &in_data,
                     const std::vector<OpReqType> &req,
@@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
                                   mkldnn::memory::format::any);
   }

-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
-  auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
-  auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+  MKLDNNFullyConnectForward &FCFwd =
+      GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+               out_md, ctx.is_train);
+  auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+  auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
-  if (param.no_bias) {
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
-        ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+      FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+  if (!param.no_bias) {
+    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+        FCFwd.ipFwd_pd.bias_primitive_desc());
+    FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
   } else {
-    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-        *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+    FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
   }
+  MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }