diff --git a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
index d9e884e82806..2a4c6d612e65 100644
--- a/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
+++ b/src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
@@ -64,22 +64,32 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
   std::vector<const mkldnn::memory*> data_mem;
   // new_data_mem is for auto-free new created mkldnn memory
   std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
+  const auto out_dtype = out_data[quantized_concat_enum::kOut].dtype();
   for (int i = 0; i < param_.num_args; ++i) {
     auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
     if (i_scale == out_scale) {
+      CHECK(in_data[i].dtype() == out_dtype);
       auto mem = in_data[i].GetMKLDNNData();
       data_mem.push_back(mem);
       data_md.push_back(mem->get_primitive_desc());
     } else {
       auto mem = in_data[i].GetMKLDNNData();
       auto pd = mem->get_primitive_desc();
+      if (in_data[i].dtype() != out_dtype) {
+        auto mem_desc = pd.desc();
+        mkldnn::memory::desc new_md(
+            mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
+            get_mkldnn_type(out_dtype), static_cast<mkldnn::memory::format>(mem_desc.data.format));
+        pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine());
+      }
       const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
       new_data_mem.push_back(rescaled_mem);
       std::vector<float> reorder_scale = {out_scale / i_scale};
       primitive_attr reorder_attr;
       reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
       reorder_attr.set_output_scales(0, reorder_scale);
-      const auto reorder_pd = mkldnn::reorder::primitive_desc(pd, pd, reorder_attr);
+      const auto reorder_pd =
+          mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr);
       MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
       data_mem.push_back(rescaled_mem.get());
       data_md.push_back(pd);
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index b25fefc6cc0e..563fff1a6aa1 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -401,6 +401,15 @@ def single_concat(data_shape, input_num, dim):
   concat = mx.symbol.Concat(*inputs, name="concat", dim=dim)
   return concat
 
+def single_concat_pos_neg(data_shape):
+  data, weight = head_symbol(data_shape)
+  conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=4,
+                               kernel=(1, 1), stride=(1, 1), no_bias=True)
+  relu = mx.symbol.Activation(data=conv, name='relu', act_type='relu')
+  inputs = [data, relu]
+  concat = mx.symbol.Concat(*inputs, name="concat", dim=1)
+  return concat
+
 # concat scale alignment case
 def concat_scale_align(data_shape):
   data, weight = head_symbol(data_shape)
@@ -738,6 +747,8 @@ def test_pos_single_concat():
       net = single_concat(data_shape, 4, 3)
       check_quantize(net, data_shape, out_type, name='conv', check_calibration=False)
       check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True)
+      net = single_concat_pos_neg(data_shape)
+      check_quantize(net, data_shape, out_type, name='', check_calibration=False)
 
 @with_seed()
 def test_pos_concat_scale_align():