diff --git a/src/engine/naive_engine.cc b/src/engine/naive_engine.cc index 93853c459298..7291f46b9552 100644 --- a/src/engine/naive_engine.cc +++ b/src/engine/naive_engine.cc @@ -142,6 +142,10 @@ class NaiveEngine final : public Engine { opr->opr_name); } +/*! + * \brief NaiveEngine's PushAsync is intentionally synchronous. + * Users should not make any assumptions about execution order when using the async interface of any engine. + */ void PushAsync(AsyncFn exec_fun, Context exec_ctx, std::vector const& const_vars, diff --git a/src/operator/nn/mkldnn/mkldnn_convolution.cc b/src/operator/nn/mkldnn/mkldnn_convolution.cc index a394edeef841..6a91ae0d92a1 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_convolution.cc @@ -411,11 +411,12 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param, // For inference, we want to reorder the weight array so we don't need to // reorder data every time. if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), - param.conv_param.num_group); // We also need to modify the layout on the original weight array. The // data conversion happens after the weight array is used. weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), + param.conv_param.num_group); + } else { weight_mem = weight.GetMKLDNNData(); CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); diff --git a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc index aec5d13c5de9..87089f389e89 100644 --- a/src/operator/nn/mkldnn/mkldnn_deconvolution.cc +++ b/src/operator/nn/mkldnn/mkldnn_deconvolution.cc @@ -262,10 +262,11 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param, // For inference, we want to reorder the weight array so we don't need to // reorder data every time. 
if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group); - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. + // We also need to modify the layout on the original weight array. + // Don't switch the sequence below, because the naive engine executes + // PushAsync synchronously. const_cast(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc()); + weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group); } else { weight_mem = weight.GetMKLDNNData(); CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc()); diff --git a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc index 03d7e62da399..1dfd2a95f338 100644 --- a/src/operator/nn/mkldnn/mkldnn_fully_connected.cc +++ b/src/operator/nn/mkldnn/mkldnn_fully_connected.cc @@ -253,8 +253,11 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param, weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); } else { if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); + // We also need to modify the layout on the original weight array. + // Don't switch the sequence below, because the naive engine executes + // PushAsync synchronously. 
weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc()); + weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1); } else { weight_mem = weight.GetMKLDNNData(); CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc()); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc index 55028d8c8ccc..f81071704762 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc @@ -52,10 +52,11 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs, // For inference, we want to reorder the weight array so we don't need to // reorder data every time. if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); - // We also need to modify the layout on the original weight array. The - // data conversion happens after the weight array is used. + // We also need to modify the layout on the original weight array. + // Don't switch the sequence below, because the naive engine executes + // PushAsync synchronously. 
weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group); } else { weight_mem = weight.GetMKLDNNData(); CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc()); diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc index e8abab22446e..aca129a56f3e 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_fully_connected.cc @@ -93,8 +93,11 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs, const mkldnn::memory *weight_mem = nullptr; if (weight.IsDefaultData()) { - weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1); + // We also need to modify the layout on the original weight array. + // Don't switch the sequence below, because the naive engine executes + // PushAsync synchronously. weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc()); + weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1); } else { weight_mem = weight.GetMKLDNNData(); CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());