This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix mkldnn backend when using naive engine #15089

Merged 5 commits on Jun 1, 2019

4 changes: 4 additions & 0 deletions src/engine/naive_engine.cc
@@ -142,6 +142,10 @@ class NaiveEngine final : public Engine {
                    opr->opr_name);
   }

+  /*!
+   * \brief NaiveEngine's PushAsync is intentionally synchronous.
+   * Users should not make any assumptions about execution order when using the async interface of any engine.
+   */
   void PushAsync(AsyncFn exec_fun,
                  Context exec_ctx,
                  std::vector<VarHandle> const& const_vars,
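
The doc comment above is the heart of the fix: NaiveEngine runs PushAsync work immediately on the calling thread, while the threaded engines defer it. The stand-alone sketch below (toy types only; ToyDeferredEngine, ToyNaiveEngine and the `reordered` flag are invented for illustration, not MXNet's Engine API) shows why callers must not bake an execution-order assumption into code that uses the async interface.

// Sketch only: toy engines with a PushAsync-shaped interface, not MXNet's
// Engine. The "naive" variant runs the function immediately on the calling
// thread; the "deferred" variant queues it, as a threaded engine would.
#include <functional>
#include <iostream>
#include <queue>
#include <utility>

using AsyncFn = std::function<void()>;

struct ToyDeferredEngine {
  std::queue<AsyncFn> pending;
  void PushAsync(AsyncFn fn) { pending.push(std::move(fn)); }  // runs later
  void WaitForAll() {
    while (!pending.empty()) { pending.front()(); pending.pop(); }
  }
};

struct ToyNaiveEngine {
  void PushAsync(AsyncFn fn) { fn(); }  // runs right now, synchronously
  void WaitForAll() {}                  // nothing is ever pending
};

template <typename Engine>
void Demo(Engine* engine, const char* name) {
  bool reordered = false;                      // stands in for a layout change
  engine->PushAsync([&] { reordered = true; });
  // With the naive engine the work has already run here; with the deferred
  // engine it has not. Callers therefore cannot assume either order.
  std::cout << name << ": reordered right after PushAsync? " << reordered << '\n';
  engine->WaitForAll();
  std::cout << name << ": reordered after WaitForAll?     " << reordered << '\n';
}

int main() {
  ToyDeferredEngine deferred;
  ToyNaiveEngine naive;
  Demo(&deferred, "deferred");
  Demo(&naive, "naive");
}
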
5 changes: 3 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -411,11 +411,12 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
-                              param.conv_param.num_group);
       // We also need to modify the layout on the original weight array. The
       // data conversion happens after the weight array is used.
       weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
+                              param.conv_param.num_group);
+
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
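
The convolution change above, and the deconvolution, fully-connected and quantized hunks below, all apply the same two-step rule: issue MKLDNNDataReorderAsync on the weight array first, then call GetWeights. The sketch below (toy types only; ToyWeights, GetWeightsToy, the layout flag and the float buffer are invented stand-ins, not the MXNet API) illustrates why the order matters once PushAsync may run synchronously.

// Sketch only: a toy weight array and a GetWeights-like helper that reuses
// the array's buffer when it is already in the target layout and otherwise
// converts into a caller-provided temporary.
#include <functional>
#include <iostream>
#include <utility>
#include <vector>

struct ToyWeights {
  bool mkldnn_layout = false;
  std::vector<float> data{1.f, 2.f, 3.f};
  void ReorderToMKLDNN() {                 // stands in for the async reorder's work
    mkldnn_layout = true;
    std::swap(data.front(), data.back());  // pretend the layout changed
  }
};

const std::vector<float>* GetWeightsToy(const ToyWeights& w,
                                        std::vector<float>* scratch) {
  if (w.mkldnn_layout) return &w.data;     // already converted: use it directly
  *scratch = w.data;                       // otherwise convert a temporary copy
  std::swap(scratch->front(), scratch->back());
  return scratch;
}

int main() {
  auto push_async = [](std::function<void()> fn) { fn(); };  // naive engine: runs now

  // Old order: fetch first, then issue the reorder. Under a synchronous
  // PushAsync the array is converted immediately after memory derived from
  // its old state was handed out, so the two can drift apart before use.
  ToyWeights w_old;
  std::vector<float> scratch_old;
  const std::vector<float>* mem_old = GetWeightsToy(w_old, &scratch_old);
  push_async([&] { w_old.ReorderToMKLDNN(); });
  std::cout << "old order handed out a temporary: " << (mem_old != &w_old.data) << '\n';

  // New order (this PR): issue the reorder first, then fetch. If the engine
  // ran it synchronously, the fetch sees the converted array; if the engine
  // deferred it, the fetch converts a temporary, consistent either way.
  ToyWeights w_new;
  std::vector<float> scratch_new;
  push_async([&] { w_new.ReorderToMKLDNN(); });
  const std::vector<float>* mem_new = GetWeightsToy(w_new, &scratch_new);
  std::cout << "new order reads the converted array: " << (mem_new == &w_new.data) << '\n';
}
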
7 changes: 4 additions & 3 deletions src/operator/nn/mkldnn/mkldnn_deconvolution.cc
@@ -262,10 +262,11 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
-      // We also need to modify the layout on the original weight array. The
-      // data conversion happens after the weight array is used.
+      // We also need to modify the layout on the original weight array.
+      // Don't switch the order of the next two statements: the naive engine
+      // executes PushAsync synchronously.
       const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
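
One detail specific to the deconvolution hunk: the reorder call goes through const_cast<NDArray&>, presumably because weight is bound by a const reference at that call site while the layout change mutates the array. A minimal stand-alone illustration of that pattern follows (toy Array type, not MXNet's NDArray).

// Sketch only: why a const_cast appears when an in-place layout change has
// to be issued on an object reached through a const reference.
struct Array {
  int layout = 0;
  void ReorderAsync(int target) { layout = target; }  // non-const: mutates the array
};

void UseWeights(const Array& weight) {
  // weight.ReorderAsync(1);                  // would not compile: weight is const here
  const_cast<Array&>(weight).ReorderAsync(1); // the pattern used in the hunk above
}

int main() {
  Array w;                // w itself is not const, so the cast is well-defined
  UseWeights(w);
  return w.layout == 1 ? 0 : 1;  // exits 0 when the reorder took effect
}
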
5 changes: 4 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_fully_connected.cc
@@ -253,8 +253,11 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
     weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
   } else {
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
+      // We also need to modify the layout on the original weight array.
+      // Don't switch the order of the next two statements: the naive engine
+      // executes PushAsync synchronously.
       weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
7 changes: 4 additions & 3 deletions src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -52,10 +52,11 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
     // For inference, we want to reorder the weight array so we don't need to
     // reorder data every time.
     if (weight.IsDefaultData()) {
-      weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
-      // We also need to modify the layout on the original weight array. The
-      // data conversion happens after the weight array is used.
+      // We also need to modify the layout on the original weight array.
+      // Don't switch the order of the next two statements: the naive engine
+      // executes PushAsync synchronously.
       weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+      weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
     } else {
       weight_mem = weight.GetMKLDNNData();
       CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
@@ -93,8 +93,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
   const mkldnn::memory *weight_mem = nullptr;

   if (weight.IsDefaultData()) {
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
+    // We also need to modify the layout on the original weight array.
+    // Don't switch the order of the next two statements: the naive engine
+    // executes PushAsync synchronously.
     weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
+    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
   } else {
     weight_mem = weight.GetMKLDNNData();
     CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());