
Enable primitive allocation cache for _backward_FullyConnected
Change-Id: I8347527ec1271b1518921a74e3581d7d84187429
Huang, Zhiyuan authored and ZhennanQin committed Jun 21, 2018
1 parent 972219e commit 2f64725
Showing 2 changed files with 286 additions and 38 deletions.
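
The change applies the primitive-caching pattern already used by other MKL-DNN operators in MXNet: instead of rebuilding the MKL-DNN primitive descriptors and primitives on every forward/backward call, they are kept in a thread-local map keyed by the operator parameters and the input signature, and reused on later calls with the same key. The sketch below only illustrates that general pattern; it is not code from this commit, and the names Key, KeyHash, CachedPrimitive, and GetCachedPrimitive are hypothetical.

#include <cstddef>
#include <functional>
#include <unordered_map>

struct Key {  // stands in for the operator parameters plus input signature
  int num_hidden;
  bool no_bias;
  bool operator==(const Key &other) const {
    return num_hidden == other.num_hidden && no_bias == other.no_bias;
  }
};

struct KeyHash {
  std::size_t operator()(const Key &k) const {
    return std::hash<int>()(k.num_hidden) * 2 + (k.no_bias ? 1 : 0);
  }
};

struct CachedPrimitive { /* expensive-to-construct MKL-DNN objects live here */ };

CachedPrimitive &GetCachedPrimitive(const Key &key) {
  // One cache per thread, so lookups need no locking.
  static thread_local std::unordered_map<Key, CachedPrimitive, KeyHash> cache;
  auto it = cache.find(key);
  if (it == cache.end())
    it = cache.emplace(key, CachedPrimitive{}).first;  // build only on a cache miss
  return it->second;                                   // reused on later calls
}

In the diff below, FullyConnectedParam (after gaining operator== and a std::hash specialization in fully_connected-inl.h) plays the role of the key, and the new MKLDNNFullyConnectForward and MKLDNNFullyConnectBackward classes play the role of the cached value.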
17 changes: 17 additions & 0 deletions src/operator/nn/fully_connected-inl.h
@@ -60,6 +60,11 @@ struct FullyConnectedParam : public dmlc::Parameter<FullyConnectedParam> {
DMLC_DECLARE_FIELD(flatten).set_default(true)
.describe("Whether to collapse all but the first axis of the input data tensor.");
}
bool operator==(const FullyConnectedParam& other) const {
return this->num_hidden == other.num_hidden &&
this->no_bias == other.no_bias &&
this->flatten == other.flatten;
}
};

template<typename xpu, typename DType>
@@ -227,4 +232,16 @@ void FullyConnectedGradCompute(const nnvm::NodeAttrs& attrs,

} // namespace op
} // namespace mxnet
namespace std {
template<>
struct hash<mxnet::op::FullyConnectedParam> {
size_t operator()(const mxnet::op::FullyConnectedParam& val) {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.num_hidden);
ret = dmlc::HashCombine(ret, val.no_bias);
ret = dmlc::HashCombine(ret, val.flatten);
return ret;
}
};
} // namespace std
#endif // MXNET_OPERATOR_NN_FULLY_CONNECTED_INL_H_
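
The two additions to fully_connected-inl.h exist to make FullyConnectedParam usable as a lookup key: operator== gives the equality check an unordered container needs, and the std::hash specialization (built up with dmlc::HashCombine over each field) gives it a hash. A minimal, self-contained illustration of those same requirements is sketched below; HashCombineLocal and ParamLike are stand-ins invented for the example, not the dmlc or MXNet definitions, and the exact dmlc::HashCombine implementation may differ.

#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Boost-style hash combining, shown only to illustrate what a HashCombine
// helper does: fold one more field into a running hash value.
template <typename T>
std::size_t HashCombineLocal(std::size_t seed, const T &v) {
  return seed ^ (std::hash<T>()(v) + 0x9e3779b9 + (seed << 6) + (seed >> 2));
}

struct ParamLike {  // hypothetical stand-in for FullyConnectedParam
  int num_hidden;
  bool no_bias;
  bool flatten;
  bool operator==(const ParamLike &o) const {
    return num_hidden == o.num_hidden && no_bias == o.no_bias && flatten == o.flatten;
  }
};

namespace std {
template <>
struct hash<ParamLike> {
  size_t operator()(const ParamLike &p) const {
    size_t ret = 0;
    ret = HashCombineLocal(ret, p.num_hidden);
    ret = HashCombineLocal(ret, p.no_bias);
    ret = HashCombineLocal(ret, p.flatten);
    return ret;
  }
};
}  // namespace std

// With equality and hashing defined, the parameter type can key a cache directly:
std::unordered_map<ParamLike, std::string> example_cache;

Any field added to the parameter struct later would need to be folded into both the equality operator and the hash, otherwise two configurations differing only in that field could collide in the cache.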
307 changes: 269 additions & 38 deletions src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -82,6 +82,101 @@ inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWei
}
}

class MKLDNNFullyConnectForward {
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> weight;
std::shared_ptr<mkldnn::memory> out;
std::shared_ptr<mkldnn::memory> bias;
std::shared_ptr<mkldnn::inner_product_forward> ipFwd;

public:
mkldnn::inner_product_forward::primitive_desc ipFwd_pd;

MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
const NDArray &data, const NDArray &weight,
const NDArray *bias,
const mkldnn::memory::desc &output)
: ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}

void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
const mkldnn::memory *bias, const mkldnn::memory &output) {
if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
else
this->data->set_data_handle(data.get_data_handle());

if (this->weight == nullptr)
this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
else
this->weight->set_data_handle(weight.get_data_handle());

if (this->out == nullptr)
this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
else
this->out->set_data_handle(output.get_data_handle());

if (bias != nullptr) {
if (this->bias == nullptr)
this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
else
this->bias->set_data_handle(bias->get_data_handle());
if (this->ipFwd == nullptr)
this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
new mkldnn::inner_product_forward(
ipFwd_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->weight),
mkldnn::primitive::at(*this->bias), *this->out));
} else if (this->ipFwd == nullptr) {
this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
new mkldnn::inner_product_forward(
ipFwd_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->weight), *this->out));
}
}

const mkldnn::inner_product_forward &GetIpFwd() const {
return *ipFwd;
}
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

static inline MKLDNNFullyConnectForward &GetFCFwd(
const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
const NDArray *bias, const mkldnn::memory::desc &output,
const bool is_train) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNFullyconSignature,
MKLDNNFullyConnectForward, OpHash> fcFwds;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
MKLDNNFullyConnectForward, OpHash> fcFwds;
#endif
const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
MKLDNNFullyconSignature key(param);
key.AddSign(data);
key.AddSign(weight);
key.AddSign(is_train);

if (bias)
key.AddSign(*bias);

auto it = fcFwds.find(key);
if (it == fcFwds.end()) {
MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
output);
auto ins_ret = fcFwds.insert(
std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
@@ -112,25 +207,168 @@ void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
mkldnn::memory::format::any);
}

mkldnn::inner_product_forward::primitive_desc ipFwd_pd = GetIPFwd(data, weight,
param.no_bias ? nullptr : &in_data[fullc::kBias], out_md, ctx.is_train);
auto data_mem = data.GetMKLDNNDataReorder(ipFwd_pd.src_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipFwd_pd.weights_primitive_desc());
MKLDNNFullyConnectForward &FCFwd =
GetFCFwd(attrs, data, weight, param.no_bias ? nullptr : &in_data[fullc::kBias],
out_md, ctx.is_train);
auto data_mem = data.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.src_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(FCFwd.ipFwd_pd.weights_primitive_desc());
auto out_mem = CreateMKLDNNMem(out_data[fullc::kOut],
ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
if (param.no_bias) {
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(
ipFwd_pd, *data_mem, *weight_mem, *out_mem.second));
FCFwd.ipFwd_pd.dst_primitive_desc(), req[fullc::kOut]);
if (!param.no_bias) {
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(
FCFwd.ipFwd_pd.bias_primitive_desc());
FCFwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
} else {
auto bias_mem = in_data[fullc::kBias].GetMKLDNNDataReorder(ipFwd_pd.bias_primitive_desc());
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_forward(ipFwd_pd,
*data_mem, *weight_mem, *bias_mem, *out_mem.second));
FCFwd.SetNewMem(*data_mem, *weight_mem, nullptr, *out_mem.second);
}
MKLDNNStream::Get()->RegisterPrim(FCFwd.GetIpFwd());
CommitOutput(out_data[fullc::kOut], out_mem);
MKLDNNStream::Get()->Submit();
}

class MKLDNNFullyConnectBackward {
std::shared_ptr<mkldnn::inner_product_backward_data> ipBwdData;
std::shared_ptr<mkldnn::inner_product_backward_weights> ipBwdWeight;
std::shared_ptr<mkldnn::memory> out_grad;
std::shared_ptr<mkldnn::memory> weight;
std::shared_ptr<mkldnn::memory> in_grad;
std::shared_ptr<mkldnn::memory> data;
std::shared_ptr<mkldnn::memory> in_grad_weight;
std::shared_ptr<mkldnn::memory> in_grad_bias;

public:
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd;
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd;

public:
MKLDNNFullyConnectBackward(
const FullyConnectedParam &param, const NDArray &data,
const NDArray &weight, const std::vector<NDArray> &in_grad,
const NDArray &out_grad, const std::vector<OpReqType> &req,
const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd)
: ipBwdData_pd(GetIpBwdData(data, weight, out_grad, ipFwd_pd)),
ipBwdWeights_pd(GetIPBwdWeights(
data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
out_grad, ipFwd_pd)) {}

void SetNewMemData(const mkldnn::memory &out_grad,
const mkldnn::memory &weight,
const mkldnn::memory &in_grad) {
if (this->out_grad == nullptr)
this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipBwdData_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
else
this->out_grad->set_data_handle(out_grad.get_data_handle());

if (this->weight == nullptr)
this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipBwdData_pd.weights_primitive_desc(), weight.get_data_handle()));
else
this->weight->set_data_handle(weight.get_data_handle());

if (this->in_grad == nullptr)
this->in_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipBwdData_pd.diff_src_primitive_desc(), in_grad.get_data_handle()));
else
this->in_grad->set_data_handle(in_grad.get_data_handle());

if (this->ipBwdData == nullptr)
this->ipBwdData = std::shared_ptr<mkldnn::inner_product_backward_data>(
new mkldnn::inner_product_backward_data(
this->ipBwdData_pd, mkldnn::primitive::at(*this->out_grad),
mkldnn::primitive::at(*this->weight), *this->in_grad));
}

void SetNewWeightMem(const FullyConnectedParam &param,
const mkldnn::memory &data,
const mkldnn::memory &out_grad,
const mkldnn::memory &in_grad_weight,
const mkldnn::memory &in_grad_bias) {
if (this->data == nullptr)
this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipBwdWeights_pd.src_primitive_desc(), data.get_data_handle()));
else
this->data->set_data_handle(data.get_data_handle());

if (this->out_grad == nullptr)
this->out_grad = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
ipBwdWeights_pd.diff_dst_primitive_desc(), out_grad.get_data_handle()));
else
this->out_grad->set_data_handle(out_grad.get_data_handle());

if (this->in_grad_weight == nullptr)
this->in_grad_weight = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(ipBwdWeights_pd.diff_weights_primitive_desc(),
in_grad_weight.get_data_handle()));
else
this->in_grad_weight->set_data_handle(in_grad_weight.get_data_handle());

if (!param.no_bias) {
if (this->in_grad_bias == nullptr)
this->in_grad_bias = std::shared_ptr<mkldnn::memory>(
new mkldnn::memory(ipBwdWeights_pd.diff_bias_primitive_desc(),
in_grad_bias.get_data_handle()));
else
this->in_grad_bias->set_data_handle(in_grad_bias.get_data_handle());

if (this->ipBwdWeight == nullptr)
this->ipBwdWeight = std::shared_ptr<mkldnn::inner_product_backward_weights>(
new mkldnn::inner_product_backward_weights(
this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight, *this->in_grad_bias));
} else {
if (this->ipBwdWeight == nullptr)
this->ipBwdWeight = std::shared_ptr<mkldnn::inner_product_backward_weights>(
new mkldnn::inner_product_backward_weights(
this->ipBwdWeights_pd, mkldnn::primitive::at(*this->data),
mkldnn::primitive::at(*this->out_grad), *this->in_grad_weight));
}
}

const mkldnn::inner_product_backward_data &GetBwdData() const {
return *ipBwdData;
}

const mkldnn::inner_product_backward_weights &GetBwdWeights() const {
return *ipBwdWeight;
}
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

static inline MKLDNNFullyConnectBackward &GetFCBwd(
const FullyConnectedParam &param, const NDArray &data,
const NDArray &weight, const std::vector<NDArray> &in_grad,
const NDArray &out_grad, const std::vector<OpReqType> &req,
const mkldnn::inner_product_forward::primitive_desc &ipFwd_pd) {
#if DMLC_CXX11_THREAD_LOCAL
static thread_local std::unordered_map<MKLDNNFullyconSignature,
MKLDNNFullyConnectBackward, OpHash>
bwdDatas;
#else
static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
MKLDNNFullyConnectBackward, OpHash>
bwdDatas;
#endif
MKLDNNFullyconSignature key(param);
key.AddSign(data);
key.AddSign(weight);
key.AddSign(in_grad);
key.AddSign(out_grad);

auto it = bwdDatas.find(key);
if (it == bwdDatas.end()) {
MKLDNNFullyConnectBackward bwdData(param, data, weight, in_grad, out_grad,
req, ipFwd_pd);
auto ins_ret = bwdDatas.insert(
std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectBackward>(
key, bwdData));
CHECK(ins_ret.second);
it = ins_ret.first;
}
return it->second;
}

void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
@@ -161,41 +399,34 @@ void MKLDNNFCBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
param.no_bias ? nullptr : &in_grad[fullc::kBias], GetMemDesc(out_grad), ctx.is_train);

CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
MKLDNNFullyConnectBackward &FCBwd =
GetFCBwd(param, data, weight, in_grad, out_grad, req, ipFwd_pd);
if (req[fullc::kData]) {
mkldnn::inner_product_backward_data::primitive_desc ipBwdData_pd = GetIpBwdData(
data, weight, out_grad, ipFwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdData_pd.diff_dst_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(ipBwdData_pd.weights_primitive_desc());
FCBwd.ipBwdData_pd.diff_dst_primitive_desc());
auto weight_mem = weight.GetMKLDNNDataReorder(FCBwd.ipBwdData_pd.weights_primitive_desc());
auto in_grad_mem = CreateMKLDNNMem(in_grad[fullc::kData],
ipBwdData_pd.diff_src_primitive_desc(),
FCBwd.ipBwdData_pd.diff_src_primitive_desc(),
req[fullc::kData]);
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_data(
ipBwdData_pd, *out_grad_mem, *weight_mem, *in_grad_mem.second));
CommitOutput(in_grad[fullc::kData], in_grad_mem);
FCBwd.SetNewMemData(*out_grad_mem, *weight_mem, *in_grad_mem.second);
MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdData());
}
if (req[fullc::kWeight]) {
mkldnn::inner_product_backward_weights::primitive_desc ipBwdWeights_pd
= GetIPBwdWeights(data, weight, param.no_bias ? nullptr : &in_grad[fullc::kBias],
out_grad, ipFwd_pd);
auto out_grad_mem = out_grad.GetMKLDNNDataReorder(
ipBwdWeights_pd.diff_dst_primitive_desc());
auto data_mem = data.GetMKLDNNDataReorder(ipBwdWeights_pd.src_primitive_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(in_grad[fullc::kWeight],
ipBwdWeights_pd.diff_weights_primitive_desc(),
req[fullc::kWeight]);
FCBwd.ipBwdWeights_pd.diff_dst_primitive_desc());
auto data_mem =
data.GetMKLDNNDataReorder(FCBwd.ipBwdWeights_pd.src_primitive_desc());
auto in_grad_weight = CreateMKLDNNWeightGrad(
in_grad[fullc::kWeight],
FCBwd.ipBwdWeights_pd.diff_weights_primitive_desc(),
req[fullc::kWeight]);
mkldnn_output_t in_grad_bias;
if (param.no_bias) {
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second));
} else {
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
ipBwdWeights_pd.diff_bias_primitive_desc(),
in_grad_bias = CreateMKLDNNMem(in_grad[fullc::kBias],
FCBwd.ipBwdWeights_pd.diff_bias_primitive_desc(),
req[fullc::kBias]);
MKLDNNStream::Get()->RegisterPrim(mkldnn::inner_product_backward_weights(
ipBwdWeights_pd, *data_mem, *out_grad_mem, *in_grad_weight.second,
*in_grad_bias.second));
}
FCBwd.SetNewWeightMem(param, *data_mem, *out_grad_mem,
*in_grad_weight.second, *in_grad_bias.second);
MKLDNNStream::Get()->RegisterPrim(FCBwd.GetBwdWeights());
CommitOutput(in_grad[fullc::kWeight], in_grad_weight);
CommitOutput(in_grad[fullc::kBias], in_grad_bias);
}
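Beyond caching the primitives themselves, the new forward and backward classes follow an allocate-once, rebind-afterwards idiom: SetNewMem, SetNewMemData, and SetNewWeightMem construct the mkldnn::memory wrappers and the primitive on the first call and on later calls only swap in the new buffers via set_data_handle, while the cache key (the param plus AddSign over the data, weight, bias, gradient arrays and is_train) ensures a primitive is only reused for matching inputs. A generic sketch of that rebinding idiom is given below under those assumptions; it is illustration only, not the MKL-DNN API, and Buffer and BindInput are invented names.

#include <memory>

struct Buffer {                                  // stands in for mkldnn::memory
  void *handle;
  explicit Buffer(void *h) : handle(h) {}        // construction is the costly part
  void set_data_handle(void *h) { handle = h; }  // rebinding a new pointer is cheap
};

// Create the wrapper on the first call, then only update the pointer it wraps.
void BindInput(std::shared_ptr<Buffer> *cached, void *new_handle) {
  if (*cached == nullptr)
    *cached = std::make_shared<Buffer>(new_handle);
  else
    (*cached)->set_data_handle(new_handle);
}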
