From 5c7d151527013a8502630126fbd7d71fb71c6909 Mon Sep 17 00:00:00 2001 From: rongzha1 Date: Mon, 28 Oct 2019 11:04:45 +0800 Subject: [PATCH] [mkldnn-v1.0]rm int8 sum workaround (#16623) * rm int8 sum workaround due to mkldnn lib update * simple dims asignments in mkldnn_quantized_elemwise_add.cc --- .../quantization/mkldnn/mkldnn_quantized_elemwise_add.cc | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc index e1e76817f875..2078ac4fead8 100644 --- a/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc +++ b/src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc @@ -161,8 +161,11 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons std::vector in_desc; in_desc.push_back(dataA_mem->get_desc()); in_desc.push_back(dataB_mem->get_desc()); - auto output_desc = dataA_mem->get_desc(); - output_desc.data.data_type = static_cast(output_data_type); + const auto in_shape = in_data[quantized_elemwise_add_enum::kDataA].shape(); + mkldnn::memory::dims i_dims(in_shape.begin(), in_shape.end()); + auto output_desc = mkldnn::memory::desc(i_dims, + output_data_type, + mkldnn::memory::format_tag::any); mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine); auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut], pdesc.dst_desc(),