// Patch commentary (kept ahead of the first `diff --git` header so that it
// stays outside every hunk).
//
// Purpose: let the MKLDNN fully-connected forward path fetch a cached
// MKLDNNFullyConnectForward object and reuse its inner-product primitive,
// instead of constructing mkldnn::inner_product_forward inside
// MKLDNNFCForward() on every call (see the removed `-` lines in the last hunk).
//
// src/operator/nn/fully_connected-inl.h
//   * FullyConnectedParam gains operator==, comparing num_hidden, no_bias and
//     flatten.
//   * A std::hash specialization taking mxnet::op::FullyConnectedParam folds
//     the same three fields with dmlc::HashCombine, starting from 0.
//     NOTE(review): its operator() is not const-qualified; the standard Hash
//     requirements allow the hasher to be invoked through a const object --
//     confirm whether `const` should be added.
//
// src/operator/nn/mkldnn/mkldnn_fully_connected.cc
//   * MKLDNNFullyConnectForward (new class) exposes the public descriptor
//     `ipFwd_pd`, initialised in the constructor from
//     GetIPFwd(data, weight, bias, output, is_train), and keeps shared_ptr
//     members for the data, weight, output and bias memories and for the
//     primitive. The constructor's first parameter (the FullyConnectedParam)
//     is not used by the initializer list, and the constructor body is empty.
//   * SetNewMem(): for each of data / weight / out (and bias when the bias
//     pointer is non-null) it creates an mkldnn::memory from the matching
//     ipFwd_pd.*_primitive_desc() and the caller's data handle if the member
//     is still null; otherwise it only calls set_data_handle() on the
//     existing object. The inner_product_forward primitive is constructed
//     only while `ipFwd` is null, with the bias argument when bias is given
//     and without it otherwise.
//   * GetIpFwd() returns `*ipFwd` with no null check; MKLDNNFCForward() in
//     this patch calls SetNewMem() on both branches before GetIpFwd().
//   * GetFCFwd() (new) keeps a function-local static map declared with
//     `thread_local` or MX_THREAD_LOCAL (chosen by DMLC_CXX11_THREAD_LOCAL).
//     The key is an MKLDNNFullyconSignature built from the op's
//     FullyConnectedParam plus data, weight, is_train and, when the pointer is
//     non-null, *bias. On a miss it constructs an MKLDNNFullyConnectForward,
//     inserts a (key, object) pair, CHECKs that the insert took place, and
//     returns a reference to the stored entry.
//     NOTE(review): `output` is passed to the constructor but is never added
//     to the key -- presumably it is determined by the other key parts;
//     verify against the caller.
//   * MKLDNNFCForward() now obtains FCFwd via GetFCFwd(), reorders data,
//     weight and (unless param.no_bias) bias against FCFwd.ipFwd_pd, binds
//     them with SetNewMem() (nullptr for bias when param.no_bias), and
//     registers FCFwd.GetIpFwd() with MKLDNNStream before CommitOutput() and
//     Submit(). CreateMKLDNNMem() is now passed an additional `&data`
//     argument.
//
// NOTE(review): the patch text below looks damaged by extraction and should
// be checked against the upstream change before it is applied:
//   - line breaks are collapsed, so the hunks are not machine-applicable;
//   - angle-bracket template arguments appear to be missing (for example
//     `struct hash {`, `std::shared_ptr data;`, `nnvm::get(attrs.parsed)`,
//     `std::unordered_map fcFwds;`);
//   - `const FullyConnectedParam ¶m` is presumably `&param` with `&para`
//     decoded as an HTML entity -- confirm.
diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h index 7eba2e20e573..bff582661189 100644 --- a/src/operator/nn/fully_connected-inl.h +++ b/src/operator/nn/fully_connected-inl.h @@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(flatten).set_default(true) .describe("Whether to collapse all but the first axis of the input data tensor."); } + bool operator==(const FullyConnectedParam& other) const { + return this->num_hidden == other.num_hidden && + this->no_bias == other.no_bias && + this->flatten == other.flatten; + } }; template @@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs, } // namespace op } // namespace mxnet +namespace std { +template<> +struct hash { + size_t operator()(const mxnet::op::FullyConnectedParam& val) { + size_t ret = 0; + ret = dmlc::HashCombine(ret, val.num_hidden); + ret = dmlc::HashCombine(ret, val.no_bias); + ret = dmlc::HashCombine(ret, val.flatten); + return ret; + } +}; +} // namespace std #endif // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_ diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index f86f8dbefa2b..5f672cd51fd5 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -82,6 +82,100 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei } } +class MKLDNNFullyConnectForward { + std::shared_ptr data; + std::shared_ptr weight; + std::shared_ptr out; + std::shared_ptr bias; + std::shared_ptr ipFwd; + + public: + mkldnn::inner_product_forward::primitive_desc ipFwd_pd; + + MKLDNNFullyConnectForward(const FullyConnectedParam ¶m, bool is_train, + const NDArray &data, const NDArray &weight, + const NDArray *bias, + const mkldnn::memory::desc &output) + : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {} + + void SetNewMem(const mkldnn::memory &data, const 
mkldnn::memory &weight, + const mkldnn::memory *bias, const mkldnn::memory &output) { + if (this->data == nullptr) + this->data = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.src_primitive_desc(), data.get_data_handle())); + else + this->data->set_data_handle(data.get_data_handle()); + + if (this->weight == nullptr) + this->weight = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.weights_primitive_desc(), weight.get_data_handle())); + else + this->weight->set_data_handle(weight.get_data_handle()); + + if (this->out == nullptr) + this->out = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.dst_primitive_desc(), output.get_data_handle())); + else + this->out->set_data_handle(output.get_data_handle()); + + if (bias != nullptr) { + if (this->bias == nullptr) + this->bias = std::shared_ptr(new mkldnn::memory( + ipFwd_pd.bias_primitive_desc(), bias->get_data_handle())); + else + this->bias->set_data_handle(bias->get_data_handle()); + if (this->ipFwd == nullptr) + this->ipFwd = std::shared_ptr( + new mkldnn::inner_product_forward( + ipFwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), + mkldnn::primitive::at(*this->bias), *this->out)); + } else if (this->ipFwd == nullptr) { + this->ipFwd = std::shared_ptr( + new mkldnn::inner_product_forward( + ipFwd_pd, mkldnn::primitive::at(*this->data), + mkldnn::primitive::at(*this->weight), *this->out)); + } + } + const mkldnn::inner_product_forward &GetIpFwd() const { + return *ipFwd; + } +}; + +typedef ParamOpSign MKLDNNFullyconSignature; + +static inline MKLDNNFullyConnectForward &GetFCFwd( + const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight, + const NDArray *bias, const mkldnn::memory::desc &output, + const bool is_train) { +#if DMLC_CXX11_THREAD_LOCAL + static thread_local std::unordered_map fcFwds; +#else + static MX_THREAD_LOCAL std::unordered_map fcFwds; +#endif + const FullyConnectedParam& param = nnvm::get(attrs.parsed); + MKLDNNFullyconSignature key(param); + 
key.AddSign(data); + key.AddSign(weight); + key.AddSign(is_train); + + if (bias) + key.AddSign(*bias); + + auto it = fcFwds.find(key); + if (it == fcFwds.end()) { + MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias, + output); + auto ins_ret = fcFwds.insert( + std::pair(key, fcFwd)); + CHECK(ins_ret.second); + it = ins_ret.first; + } + return it->second; +} + void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, const std::vector &in_data, const std::vector &req, @@ -112,21 +206,21 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()), mkldnn::memory::format::any); } - - mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight, - param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train); - auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc()); - auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc()); + MKLDNNFullyConnectForward &FCFwd = + GetFCFwd(attrs, data, weight, param.no_bias ? 
nullptr : &in_data[fullc::kBias], + out_md, ctx.is_train); + auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc()); + auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc()); auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut], - ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]); - if (param.no_bias) { - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward( - ipFwd_pd, *data_mem, *weight_mem, *out_mem.second)); + FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut], &data); + if (!param.no_bias) { + auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder( + FCFwd.ipFwd_pd.bias_primitive_desc()); + FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second); } else { - auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc()); - MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd, - *data_mem, *weight_mem, *bias_mem, *out_mem.second)); + FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second); } + MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd()); CommitOutput(out_data[fullc::kOut], out_mem); MKLDNNStream::Get()->Submit(); }