diff --git a/src/operator/nn/fully_connected-inl.h b/src/operator/nn/fully_connected-inl.h
index 7eba2e20e573..bff582661189 100644
--- a/src/operator/nn/fully_connected-inl.h
+++ b/src/operator/nn/fully_connected-inl.h
@@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
     DMLC_DECLARE_FIELD(flatten).set_default(true)
     .describe("Whether to collapse all but the first axis of the input data tensor.");
   }
+  bool operator==(const FullyConnectedParam& other) const {
+    return this->num_hidden == other.num_hidden &&
+           this->no_bias == other.no_bias &&
+           this->flatten == other.flatten;
+  }
 };
 
 template<typename xpu, typename DType>
@@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,
 
 }  // namespace op
 }  // namespace mxnet
+namespace std {
+template<>
+struct hash<mxnet::op::FullyConnectedParam> {
+  size_t operator()(const mxnet::op::FullyConnectedParam& val) {
+    size_t ret = 0;
+    ret = dmlc::HashCombine(ret, val.num_hidden);
+    ret = dmlc::HashCombine(ret, val.no_bias);
+    ret = dmlc::HashCombine(ret, val.flatten);
+    return ret;
+  }
+};
+}  // namespace std
 #endif  // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
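
The operator== overload and the std::hash<FullyConnectedParam> specialization above exist so the parameter struct can serve as part of an unordered_map key when caching MKL-DNN primitives. Below is a minimal, self-contained sketch of the same pattern; Param is a hypothetical stand-in for FullyConnectedParam, and hash_combine is a local substitute for dmlc::HashCombine (same boost-style mixing):

    #include <cstddef>
    #include <functional>
    #include <string>
    #include <unordered_map>

    struct Param {  // hypothetical stand-in for FullyConnectedParam
      int num_hidden;
      bool no_bias;
      bool flatten;
      bool operator==(const Param &other) const {
        return num_hidden == other.num_hidden &&
               no_bias == other.no_bias &&
               flatten == other.flatten;
      }
    };

    // Local stand-in for dmlc::HashCombine.
    inline size_t hash_combine(size_t seed, size_t v) {
      return seed ^ (std::hash<size_t>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
    }

    namespace std {
    template <>
    struct hash<Param> {
      size_t operator()(const Param &p) const {
        size_t ret = 0;
        ret = hash_combine(ret, static_cast<size_t>(p.num_hidden));
        ret = hash_combine(ret, static_cast<size_t>(p.no_bias));
        ret = hash_combine(ret, static_cast<size_t>(p.flatten));
        return ret;
      }
    };
    }  // namespace std

    int main() {
      std::unordered_map<Param, std::string> cache;
      cache[{512, false, true}] = "cached primitive";
      return cache.count({512, false, true}) == 1 ? 0 : 1;  // hit on equal params
    }
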
diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
index f86f8dbefa2b..12e75612fe06 100644
--- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
+++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,101 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights(
   }
 }
 
+class MKLDNNFullyConnectForward {
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> out;
+  std::shared_ptr<mkldnn::memory> bias;
+  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+
+ public:
+  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+
+  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
+                            const NDArray &data, const NDArray &weight,
+                            const NDArray *bias,
+                            const mkldnn::memory::desc &output)
+      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+
+  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
+                 const mkldnn::memory *bias, const mkldnn::memory &output) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->out == nullptr)
+      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
+    else
+      this->out->set_data_handle(output.get_data_handle());
+
+    if (bias != nullptr) {
+      if (this->bias == nullptr)
+        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
+      else
+        this->bias->set_data_handle(bias->get_data_handle());
+      if (this->ipFwd == nullptr)
+        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+            new mkldnn::inner_product_forward(
+                ipFwd_pd, mkldnn::primitive::at(*this->data),
+                mkldnn::primitive::at(*this->weight),
+                mkldnn::primitive::at(*this->bias), *this->out));
+    } else if (this->ipFwd == nullptr) {
+      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+          new mkldnn::inner_product_forward(
+              ipFwd_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->weight), *this->out));
+    }
+  }
+
+  const mkldnn::inner_product_forward &GetIpFwd() const {
+    return *ipFwd;
+  }
+};
+
+typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+
+static inline MKLDNNFullyConnectForward &GetFCFwd(
+    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &output,
+    const bool is_train) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+                                         MKLDNNFullyConnectForward,
+                                         OpHash> fcFwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+                                            MKLDNNFullyConnectForward,
+                                            OpHash> fcFwds;
+#endif
+  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(is_train);
+
+  if (bias)
+    key.AddSign(*bias);
+
+  auto it = fcFwds.find(key);
+  if (it == fcFwds.end()) {
+    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
+                                    output);
+    auto ins_ret = fcFwds.insert(
+        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
 void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                      const std::vector<NDArray> &in_data,
                      const std::vector<OpReqType> &req,
@@ -112,25 +207,168 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
     out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
                                   mkldnn::memory::format::any);
   }
-
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
-      param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
-  auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
-  auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
+  MKLDNNFullyConnectForward &FCFwd =
+      GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
+               out_md, ctx.is_train);
+  auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
+  auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
   auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
-      ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
-  if (param.no_bias) {
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
-        ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
+      FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
+  if (!param.no_bias) {
+    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
+        FCFwd.ipFwd_pd.bias_primitive_desc());
+    FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
   } else {
-    auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
-        *data_mem, *weight_mem, *bias_mem, *out_mem.second));
+    FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
   }
+  MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
   CommitOutput(out_data[fullc::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
 }
 
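
GetFCFwd follows a memoization idiom used throughout MXNet's MKL-DNN integration: a function-local thread_local map keyed by the operator parameters plus input signatures, with look-up-or-construct semantics. thread_local storage trades some memory (one cache per thread) for lock-free lookups. A simplified, self-contained sketch of the idiom, where Signature, SignatureHash, and Forward are hypothetical stand-ins for MKLDNNFullyconSignature, OpHash, and MKLDNNFullyConnectForward:

    #include <cstdint>
    #include <functional>
    #include <unordered_map>
    #include <vector>

    struct Signature {                // stand-in for MKLDNNFullyconSignature
      std::vector<uint64_t> fields;   // param hash, shapes, dtypes, is_train, ...
      bool operator==(const Signature &o) const { return fields == o.fields; }
    };

    struct SignatureHash {            // stand-in for OpHash
      size_t operator()(const Signature &s) const {
        size_t ret = 0;
        for (uint64_t f : s.fields)
          ret ^= std::hash<uint64_t>()(f) + 0x9e3779b9 + (ret << 6) + (ret >> 2);
        return ret;
      }
    };

    struct Forward { /* would own the cached primitive and its memories */ };

    // Look up an existing entry or construct-and-insert a new one.
    static Forward &GetFwd(const Signature &key) {
      static thread_local std::unordered_map<Signature, Forward, SignatureHash> fwds;
      auto it = fwds.find(key);
      if (it == fwds.end())
        it = fwds.emplace(key, Forward{}).first;
      return it->second;
    }

Returning it->second by reference is safe here: std::unordered_map never invalidates references on rehash, and the cache is never erased during the thread's lifetime.
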
+class MKLDNNFullyConnectBackward {
+  std::shared_ptr<mkldnn::inner_product_backward_data> ipBwdData;
+  std::shared_ptr<mkldnn::inner_product_backward_weights> ipBwdWeight;
+  std::shared_ptr<mkldnn::memory> out_grad;
+  std::shared_ptr<mkldnn::memory> weight;
+  std::shared_ptr<mkldnn::memory> in_grad;
+  std::shared_ptr<mkldnn::memory> data;
+  std::shared_ptr<mkldnn::memory> in_grad_weight;
+  std::shared_ptr<mkldnn::memory> in_grad_bias;
+
+ public:
+  mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd;
+  mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd;
+
+  MKLDNNFullyConnectBackward(
+      const FullyConnectedParam &param, const NDArray &data,
+      const NDArray &weight, const std::vector<NDArray> &in_grad,
+      const NDArray &out_grad, const std::vector<OpReqType> &req,
+      const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd)
+      : ipBwdData_pd(GetIpBwdData(data, weight, out_grad, ipFwd_pd)),
+        ipBwdWeights_pd(GetIPBwdWeights(
+            data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
+            out_grad, ipFwd_pd)) {}
+
+  void SetNewMemData(const mkldnn::memory &out_grad,
+                     const mkldnn::memory &weight,
+                     const mkldnn::memory &in_grad) {
+    if (this->out_grad == nullptr)
+      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipBwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->out_grad->set_data_handle(out_grad.get_data_handle());
+
+    if (this->weight == nullptr)
+      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipBwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
+    else
+      this->weight->set_data_handle(weight.get_data_handle());
+
+    if (this->in_grad == nullptr)
+      this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipBwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
+    else
+      this->in_grad->set_data_handle(in_grad.get_data_handle());
+
+    if (this->ipBwdData == nullptr)
+      this->ipBwdData = std::shared_ptr<mkldnn::inner_product_backward_data>(
+          new mkldnn::inner_product_backward_data(
+              this->ipBwdData_pd, mkldnn::primitive::at(*this->out_grad),
+              mkldnn::primitive::at(*this->weight), *this->in_grad));
+  }
+
+  void SetNewWeightMem(const FullyConnectedParam &param,
+                       const mkldnn::memory &data,
+                       const mkldnn::memory &out_grad,
+                       const mkldnn::memory &in_grad_weight,
+                       const mkldnn::memory *in_grad_bias) {
+    if (this->data == nullptr)
+      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipBwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
+    else
+      this->data->set_data_handle(data.get_data_handle());
+
+    if (this->out_grad == nullptr)
+      this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+          ipBwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
+    else
+      this->out_grad->set_data_handle(out_grad.get_data_handle());
+
+    if (this->in_grad_weight == nullptr)
+      this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
+          new mkldnn::memory(ipBwdWeights_pd.diff_weights_primitive_desc(),
+                             in_grad_weight.get_data_handle()));
+    else
+      this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());
+
+    if (!param.no_bias) {
+      // in_grad_bias must be non-null whenever the layer has a bias term.
+      if (this->in_grad_bias == nullptr)
+        this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
+            new mkldnn::memory(ipBwdWeights_pd.diff_bias_primitive_desc(),
+                               in_grad_bias->get_data_handle()));
+      else
+        this->in_grad_bias->set_data_handle(in_grad_bias->get_data_handle());
+
+      if (this->ipBwdWeight == nullptr)
+        this->ipBwdWeight = std::shared_ptr<mkldnn::inner_product_backward_weights>(
+            new mkldnn::inner_product_backward_weights(
+                this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data),
+                mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight,
+                *this->in_grad_bias));
+    } else {
+      if (this->ipBwdWeight == nullptr)
+        this->ipBwdWeight = std::shared_ptr<mkldnn::inner_product_backward_weights>(
+            new mkldnn::inner_product_backward_weights(
+                this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data),
+                mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight));
+    }
+  }
+
+  const mkldnn::inner_product_backward_data &GetBwdData() const {
+    return *ipBwdData;
+  }
+
+  const mkldnn::inner_product_backward_weights &GetBwdWeights() const {
+    return *ipBwdWeight;
+  }
+};
+
+static inline MKLDNNFullyConnectBackward &GetFCBwd(
+    const FullyConnectedParam &param, const NDArray &data,
+    const NDArray &weight, const std::vector<NDArray> &in_grad,
+    const NDArray &out_grad, const std::vector<OpReqType> &req,
+    const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<MKLDNNFullyconSignature,
+                                         MKLDNNFullyConnectBackward,
+                                         OpHash> bwdDatas;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
+                                            MKLDNNFullyConnectBackward,
+                                            OpHash> bwdDatas;
+#endif
+  MKLDNNFullyconSignature key(param);
+  key.AddSign(data);
+  key.AddSign(weight);
+  key.AddSign(in_grad);
+  key.AddSign(out_grad);
+
+  auto it = bwdDatas.find(key);
+  if (it == bwdDatas.end()) {
+    MKLDNNFullyConnectBackward bwdData(param, data, weight, in_grad, out_grad,
+                                       req, ipFwd_pd);
+    auto ins_ret = bwdDatas.insert(
+        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectBackward>(
+            key, bwdData));
+    CHECK(ins_ret.second);
+    it = ins_ret.first;
+  }
+  return it->second;
+}
+
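
SetNewMem, SetNewMemData, and SetNewWeightMem all apply one pattern: the mkldnn::memory objects, and the primitive wired to them, are constructed on first use only; every later call just repoints those memory objects at the current iteration's buffers via set_data_handle(). A library-free sketch of the pattern, where Buffer and MatMulPrimitive are hypothetical stand-ins for mkldnn::memory and a cached MKL-DNN primitive:

    #include <memory>

    // Stand-in for mkldnn::memory: a wrapper whose data pointer can be moved.
    class Buffer {
      void *handle_;
     public:
      explicit Buffer(void *h) : handle_(h) {}
      void set_data_handle(void *h) { handle_ = h; }
      void *data_handle() const { return handle_; }
    };

    // Stand-in for a primitive: built once against fixed Buffer objects.
    class MatMulPrimitive {
      Buffer *src_, *dst_;
     public:
      MatMulPrimitive(Buffer *src, Buffer *dst) : src_(src), dst_(dst) {}
      void execute() { /* would read src_->data_handle(), write dst_->data_handle() */ }
    };

    class CachedForward {
      std::unique_ptr<Buffer> src_, dst_;
      std::unique_ptr<MatMulPrimitive> prim_;
     public:
      // Called every iteration: expensive objects are created on the first
      // call only; afterwards just the raw pointers are swapped.
      void SetNewMem(void *src, void *dst) {
        if (!src_) src_.reset(new Buffer(src));
        else src_->set_data_handle(src);
        if (!dst_) dst_.reset(new Buffer(dst));
        else dst_->set_data_handle(dst);
        if (!prim_) prim_.reset(new MatMulPrimitive(src_.get(), dst_.get()));
      }
      MatMulPrimitive &GetPrim() { return *prim_; }
    };
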
 void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                       const std::vector<NDArray> &inputs,
                       const std::vector<OpReqType> &req,
@@ -161,41 +399,34 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
       param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad),
       ctx.is_train);
 
   CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
+  MKLDNNFullyConnectBackward &FCBwd =
+      GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd);
   if (req[fullc::kData]) {
-    mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData(
-        data, weight, out_grad, ipFwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdData_pd.diff_dst_primitive_desc());
-    auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
+        FCBwd.ipBwdData_pd.diff_dst_primitive_desc());
+    auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc());
     auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
-                                       ipBwdData_pd.diff_src_primitive_desc(),
+                                       FCBwd.ipBwdData_pd.diff_src_primitive_desc(),
                                        req[fullc::kData]);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
-        ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
+    FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second);
+    MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData());
     CommitOutput(in_grad[fullc::kData], in_grad_mem);
   }
   if (req[fullc::kWeight]) {
-    mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
-      = GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
-          out_grad, ipFwd_pd);
     auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
-        ipBwdWeights_pd.diff_dst_primitive_desc());
-    auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
-    auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
-                                                 ipBwdWeights_pd.diff_weights_primitive_desc(),
-                                                 req[fullc::kWeight]);
+        FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc());
+    auto data_mem =
+        data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc());
+    auto in_grad_weight = CreateMKLDNNWeightGrad(
+        in_grad[fullc::kWeight],
+        FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(),
+        req[fullc::kWeight]);
     mkldnn_output_t in_grad_bias;
-    if (param.no_bias) {
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
-    } else {
-      in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
-                                     ipBwdWeights_pd.diff_bias_primitive_desc(),
-                                     req[fullc::kBias]);
-      MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
-          ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
-          *in_grad_bias.second));
-    }
+    // Only touch the bias gradient when the layer has a bias: with no_bias
+    // the backward-weights pd has no diff_bias and in_grad has no kBias entry.
+    if (!param.no_bias)
+      in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
+                                     FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(),
+                                     req[fullc::kBias]);
+    FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem,
+                          *in_grad_weight.second,
+                          param.no_bias ? nullptr : in_grad_bias.second);
+    MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights());
     CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
     CommitOutput(in_grad[fullc::kBias], in_grad_bias);
   }
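
Both caches are only correct because of what the keys sign: GetFCFwd signs data, weight, is_train, and (when present) bias, while GetFCBwd signs data, weight, in_grad, and out_grad. Omitting any attribute that influences the primitive descriptor, is_train for instance, would let one configuration silently reuse another's primitive. A simplified sketch of signature building; SimpleSignature is a hypothetical stand-in for MXNet's OpSignature, which signs more attributes than shapes alone:

    #include <cstdint>
    #include <vector>

    // Hypothetical, simplified signature accumulator: every signed attribute
    // is appended to one flat list that defines key equality.
    class SimpleSignature {
      std::vector<int64_t> eles;
     public:
      void AddSign(int val) { eles.push_back(val); }
      void AddSign(const std::vector<int64_t> &shape) {
        for (int64_t d : shape) eles.push_back(d);
      }
      bool operator==(const SimpleSignature &o) const { return eles == o.eles; }
    };

    // Two calls share a cached primitive only if every signed attribute matches.
    int main() {
      SimpleSignature train_key, infer_key;
      train_key.AddSign({32, 1024});  // data shape
      train_key.AddSign(1);           // is_train = true
      infer_key.AddSign({32, 1024});
      infer_key.AddSign(0);           // is_train = false
      return train_key == infer_key ? 1 : 0;  // different keys -> no false sharing
    }
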