Skip to content
This repository was archived by the owner on Nov 17, 2023, and is now read-only.

Fix mkldnn backend when using naive engine #15089

Merged
merged 5 commits into the base branch on Jun 1, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,11 +411,12 @@ void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
param.conv_param.num_group);
// We also need to modify the layout on the original weight array. The
// data conversion happens after the weight array is used.
weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(),
param.conv_param.num_group);

} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_deconvolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,10 +262,10 @@ void MKLDNNDeconvForward::SetDataHandle(const DeconvolutionParam& param,
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
// We also need to modify the layout on the original weight array. The
// data conversion happens after the weight array is used.
const_cast<NDArray&>(weight).MKLDNNDataReorderAsync(fwd_pd.weights_primitive_desc());
weight_mem = GetWeights(weight, fwd_pd.weights_primitive_desc(), param.num_group);
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd_pd.weights_primitive_desc());
Expand Down
2 changes: 1 addition & 1 deletion src/operator/nn/mkldnn/mkldnn_fully_connected.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &full_param,
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
} else {
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
weight.MKLDNNDataReorderAsync(fwd->fwd_pd.weights_primitive_desc());
weight_mem = GetWeights(weight, fwd->fwd_pd.weights_primitive_desc(), 1);
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd->fwd_pd.weights_primitive_desc());
Expand Down
2 changes: 1 addition & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
// For inference, we want to reorder the weight array so we don't need to
// reorder data every time.
if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
// We also need to modify the layout on the original weight array. The
// data conversion happens after the weight array is used.
weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ void MKLDNNQuantizedFullyConnectedForward(const nnvm::NodeAttrs &attrs,
const mkldnn::memory *weight_mem = nullptr;

if (weight.IsDefaultData()) {
weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), 1);
} else {
weight_mem = weight.GetMKLDNNData();
CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
Expand Down