MKLDNN Forward FullyConnected op cache (#11611)
* Enable primitive allocation cache for FullyConnected

* Fix indent and pass in_data as the last argument for CreateMKLDNNMem

huangzhiyuan authored and eric-haibin-lin committed Aug 26, 2018
1 parent 54d5777 commit 7230bb9
Showing 2 changed files with 123 additions and 12 deletions.
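At a high level, the commit applies the per-thread primitive-cache pattern used elsewhere in MXNet's MKLDNN integration: construct the expensive MKLDNN primitive once per unique operator signature, keep it in a thread-local map, and reuse it on later forward calls. Below is a minimal self-contained sketch of that pattern; the names (Signature, CachedForward, GetCachedForward) are illustrative stand-ins, not the actual MXNet types.

#include <cstddef>
#include <functional>
#include <unordered_map>

// Illustrative stand-in for an operator signature: params plus input geometry.
struct Signature {
  int num_hidden;
  bool no_bias;
  int batch, in_dim;
  bool operator==(const Signature &o) const {
    return num_hidden == o.num_hidden && no_bias == o.no_bias &&
           batch == o.batch && in_dim == o.in_dim;
  }
};

struct SignatureHash {
  size_t operator()(const Signature &s) const {
    size_t h = std::hash<int>()(s.num_hidden);
    h = h * 31 + std::hash<bool>()(s.no_bias);
    h = h * 31 + std::hash<int>()(s.batch);
    h = h * 31 + std::hash<int>()(s.in_dim);
    return h;
  }
};

// Stand-in for the expensive object being cached (the MKLDNN primitive).
struct CachedForward { /* primitive + reusable memory wrappers */ };

CachedForward &GetCachedForward(const Signature &key) {
  // One map per thread: MKLDNN primitives must not be shared across threads.
  static thread_local std::unordered_map<Signature, CachedForward,
                                         SignatureHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, CachedForward{}).first;  // first call: create
  return it->second;                                 // later calls: reuse
}

int main() {
  Signature s{512, false, 32, 1024};
  CachedForward &fwd1 = GetCachedForward(s);  // creates the cache entry
  CachedForward &fwd2 = GetCachedForward(s);  // same signature: same entry
  return &fwd1 == &fwd2 ? 0 : 1;
}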
17 changes: 17 additions & 0 deletions src/operator/nn/fully_connected-inl.h
@@ -61,6 +61,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
     DMLC_DECLARE_FIELD(flatten).set_default(true)
     .describe("Whether to collapse all but the first axis of the input data tensor.");
   }
+  bool operator==(const FullyConnectedParam& other) const {
+    return this->num_hidden == other.num_hidden &&
+           this->no_bias == other.no_bias &&
+           this->flatten == other.flatten;
+  }
 };
 
 template<typename xpu, typename DType>
@@ -228,4 +233,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,

 }  // namespace op
 }  // namespace mxnet
+namespace std {
+template<>
+struct hash<mxnet::op::FullyConnectedParam> {
+  size_t operator()(const mxnet::op::FullyConnectedParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.num_hidden);
+    ret = dmlc::HashCombine(ret, val.no_bias);
+    ret = dmlc::HashCombine(ret, val.flatten);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
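A note on the two additions in this header: operator== and the std::hash specialization exist so that FullyConnectedParam can serve as part of an unordered_map key in the cache introduced in the second file. A minimal sketch of that equality-plus-hash contract, using a hypothetical reduced Param in place of the real struct:

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for FullyConnectedParam, limited to the three
// fields that the new operator== and hash specialization cover.
struct Param {
  int num_hidden;
  bool no_bias;
  bool flatten;
  bool operator==(const Param &o) const {
    return num_hidden == o.num_hidden && no_bias == o.no_bias &&
           flatten == o.flatten;
  }
};

// Same combine-every-field idea as the dmlc::HashCombine calls in the diff.
struct ParamHash {
  size_t operator()(const Param &p) const {
    size_t h = std::hash<int>()(p.num_hidden);
    h = h * 31 + std::hash<bool>()(p.no_bias);
    h = h * 31 + std::hash<bool>()(p.flatten);
    return h;
  }
};

int main() {
  std::unordered_map<Param, std::string, ParamHash> cache;
  cache[{512, false, true}] = "cached forward primitive";
  // Field-wise equality plus a consistent hash means an identical parameter
  // set always lands on the same cache entry.
  return cache.count({512, false, true}) == 1 ? 0 : 1;
}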
118 changes: 106 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
   }
 }
 
+class MKLDNNFullyConnectForward {
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+
+ public:
+  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+
+  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
+                            const NDArray &data, const NDArray &weight,
+                            const NDArray *bias,
+                            const mkldnn::memory::desc &output)
+      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (bias != nullptr) {
+      if (this->bias == nullptr)
+        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
+      else
+        this->bias->set_data_handle(bias->get_data_handle());
+      if (this->ipFwd == nullptr)
+        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+            new mkldnn::inner_product_forward(
+                ipFwd_pd, mkldnn::primitive::at(*this->data),
+                mkldnn::primitive::at(*this->weight),
+                mkldnn::primitive::at(*this->bias), *this->out));
+    } else if (this->ipFwd == nullptr) {
+      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+          new mkldnn::inner_product_forward(
+              ipFwd_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->weight), *this->out));
+    }
+  }
+  const mkldnn::inner_product_forward &GetIpFwd() const {
+    return *ipFwd;
+  }
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+static inline MKLDNNFullyConnectForward &GetFCFwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &output,
+    const bool is_train) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+                                         MKLDNNFullyConnectForward, OpHash> fcFwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+                                            MKLDNNFullyConnectForward, OpHash> fcFwds;
+#endif
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(is_train);
+
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = fcFwds.find(key);
+  if (it == fcFwds.end()) {
+    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
+                                    output);
+    auto ins_ret = fcFwds.insert(
+        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &in_data,
                      const std::vector<OpReqType> &req,
@@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
                                   mkldnn::memory::format::any);
   }
 
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
-  auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
-  auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+  MKLDNNFullyConnectForward &FCFwd =
+      GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+               out_md, ctx.is_train);
+  auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+  auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
-  if (param.no_bias) {
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
-        ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+      FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data);
+  if (!param.no_bias) {
+    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+        FCFwd.ipFwd_pd.bias_primitive_desc());
+    FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
   } else {
-    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-        *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+    FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
   }
+  MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
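
A design note on SetNewMem in the diff above: the mkldnn::memory wrappers and the inner_product_forward primitive are constructed only on the first call; every later call just repoints the existing wrappers at the new input and output buffers via set_data_handle(), which is where the caching saves time. A minimal sketch of that first-call-allocates, later-calls-repoint pattern; the Memory class here is an illustrative stand-in for mkldnn::memory:

#include <memory>

// Illustrative stand-in for mkldnn::memory: costly to construct, but its
// underlying buffer pointer can be swapped cheaply afterwards.
class Memory {
 public:
  explicit Memory(void *handle) : handle_(handle) { /* costly setup here */ }
  void set_data_handle(void *handle) { handle_ = handle; }
  void *get_data_handle() const { return handle_; }

 private:
  void *handle_;
};

class CachedForward {
 public:
  // Called on every forward pass with the current input buffer.
  void SetNewMem(void *input) {
    if (data_ == nullptr)
      data_ = std::make_shared<Memory>(input);  // first call: build wrapper
    else
      data_->set_data_handle(input);            // later calls: repoint only
  }

 private:
  std::shared_ptr<Memory> data_;
};

int main() {
  CachedForward fwd;
  float batch1[16] = {0}, batch2[16] = {0};
  fwd.SetNewMem(batch1);  // constructs the Memory wrapper
  fwd.SetNewMem(batch2);  // reuses it; only the data handle changes
}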
