Skip to content

Commit

Permalink
Fix quantized concat when inputs are mixed int8 and uint8 (apache#15693)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhennanQin authored and Ubuntu committed Aug 20, 2019
1 parent b29cb40 commit 3a0719a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_concat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,22 +64,32 @@ static void MKLDNNQuantizedConcatForward(const nnvm::NodeAttrs& attrs, const OpC
std::vector<const mkldnn::memory*> data_mem;
// new_data_mem holds newly created mkldnn memory so that it is freed automatically
std::vector<std::shared_ptr<mkldnn::memory>> new_data_mem;
const auto out_dtype = out_data[quantized_concat_enum::kOut].dtype();
for (int i = 0; i < param_.num_args; ++i) {
auto i_scale = GetScale(in_data[i], data_min[i], data_max[i]);
if (i_scale == out_scale) {
CHECK(in_data[i].dtype() == out_dtype);
auto mem = in_data[i].GetMKLDNNData();
data_mem.push_back(mem);
data_md.push_back(mem->get_primitive_desc());
} else {
auto mem = in_data[i].GetMKLDNNData();
auto pd = mem->get_primitive_desc();
if (in_data[i].dtype() != out_dtype) {
auto mem_desc = pd.desc();
mkldnn::memory::desc new_md(
mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
get_mkldnn_type(out_dtype), static_cast<mkldnn::memory::format>(mem_desc.data.format));
pd = mkldnn::memory::primitive_desc(new_md, CpuEngine::Get()->get_engine());
}
const auto rescaled_mem = std::make_shared<mkldnn::memory>(pd);
new_data_mem.push_back(rescaled_mem);
std::vector<float> reorder_scale = {out_scale / i_scale};
primitive_attr reorder_attr;
reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
reorder_attr.set_output_scales(0, reorder_scale);
const auto reorder_pd = mkldnn::reorder::primitive_desc(pd, pd, reorder_attr);
const auto reorder_pd =
mkldnn::reorder::primitive_desc(mem->get_primitive_desc(), pd, reorder_attr);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *mem, *rescaled_mem));
data_mem.push_back(rescaled_mem.get());
data_md.push_back(pd);
Expand Down
11 changes: 11 additions & 0 deletions tests/python/mkl/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,15 @@ def single_concat(data_shape, input_num, dim):
concat = mx.symbol.Concat(*inputs, name="concat", dim=dim)
return concat

def single_concat_pos_neg(data_shape):
    """Build a symbol that concatenates the raw input with a conv+ReLU branch.

    The input obtained from ``head_symbol(data_shape)`` is fed through a 1x1
    convolution (4 filters, no bias) followed by a ReLU activation; the
    original input and the ReLU output are then concatenated along ``dim=1``.

    NOTE(review): the ``pos_neg`` name and the accompanying commit title
    suggest the intent is to give the quantized concat one possibly-negative
    input (raw data) and one non-negative input (ReLU output) -- confirm.
    """
    data, weight = head_symbol(data_shape)
    conv_out = mx.symbol.Convolution(
        data=data,
        weight=weight,
        name='conv',
        num_filter=4,
        kernel=(1, 1),
        stride=(1, 1),
        no_bias=True,
    )
    relu_out = mx.symbol.Activation(data=conv_out, name='relu', act_type='relu')
    # Order matters: the raw data comes first, the ReLU output second.
    return mx.symbol.Concat(data, relu_out, name="concat", dim=1)

# concat scale alignment case
def concat_scale_align(data_shape):
data, weight = head_symbol(data_shape)
Expand Down Expand Up @@ -738,6 +747,8 @@ def test_pos_single_concat():
net = single_concat(data_shape, 4, 3)
check_quantize(net, data_shape, out_type, name='conv', check_calibration=False)
check_quantize(net, data_shape, out_type, name='conv', check_calibration=False, gluon_forward=True)
net = single_concat_pos_neg(data_shape)
check_quantize(net, data_shape, out_type, name='', check_calibration=False)

@with_seed()
def test_pos_concat_scale_align():
Expand Down

0 comments on commit 3a0719a

Please sign in to comment.