diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index af533978a6f5..5bd9e8af9038 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -265,7 +265,7 @@ Graph QuantizeGraph(Graph &&src) {
           (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
         // here we calculate the output number (exclude min/max, in order to
         // calculate min/max index from mirror node) based on assumption that
-        // there is only 1min and 1max output from mirror node (which is
+        // there is only 1 min and 1 max output from mirror node (which is
         // currently true)
         size_t num_outputs = mirror_node->num_outputs() - 2;
         uint32_t min_index = num_outputs + 2 * e.index;
@@ -297,9 +297,13 @@ Graph QuantizeGraph(Graph &&src) {
       // Only insert dequantize for those Ops supports quantize and not excluded.
       NodePtr mirror_node = mirror_map.at(e.node.get());
       NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
-      size_t num_inputs = e.node->num_inputs();
-      uint32_t min_index = num_inputs + 2 * e.index;
-      uint32_t max_index = num_inputs + 2 * e.index + 1;
+      // here we calculate the output number (exclude min/max, in order to
+      // calculate min/max index from mirror node) based on assumption that
+      // there is only 1 min and 1 max output from mirror node (which is
+      // currently true)
+      size_t num_outputs = e.node->num_outputs();
+      uint32_t min_index = num_outputs + 2 * e.index;
+      uint32_t max_index = num_outputs + 2 * e.index + 1;
 
       NodePtr dequantize_node = CreateNode("_contrib_dequantize",
         e.node->attrs.name + "_dequantize");
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index c8cf79e399fd..761eb47e56cb 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -85,7 +85,7 @@ def check_qsym_scale_align(qsym):
 
 
 def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -96,7 +96,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_
   return mod.get_outputs()
 
 def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -184,6 +184,55 @@ def check_quantize(sym, data_shape, out_type, name='conv',
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
 
+@with_seed()
+def check_quantize_whole_model_with_forward():
+  def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+    mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+    mod.set_params(qarg_params, qaux_params)
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+      output.wait_to_read()
+
+  def check_quantize_whole_model(out_type):
+    batch_size = 4
+    data_shape = (batch_size, 4, 10, 10)
+    data = mx.sym.Variable('data')
+    conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+    sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')
+    sym_sg = sym.get_backend_symbol('MKLDNN')
+    mod = Module(symbol=sym, label_names=[])
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+
+    mod.init_params(mx.init.Normal(0.5))
+    arg_params, aux_params = mod.get_params()
+
+    excluded_sym_names = []
+
+    calib_data = mx.nd.random.uniform(shape=data_shape)
+    calib_data = NDArrayIter(data=calib_data)
+    calib_data = DummyIter(calib_data)
+    calib_layer = lambda name: name.endswith('_output')
+    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                     arg_params=arg_params,
+                                                                     aux_params=aux_params,
+                                                                     ctx=mx.current_context(),
+                                                                     excluded_sym_names=excluded_sym_names,
+                                                                     quantized_dtype=out_type,
+                                                                     calib_mode='naive',
+                                                                     calib_data=calib_data,
+                                                                     calib_layer=calib_layer,
+                                                                     num_calib_examples=5)
+    qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
+    check_qsym_forward(qsym, qarg_params, qaux_params, data_shape)
+
+  for qdtype in ['uint8', 'int8', 'auto']:
+    check_quantize_whole_model(qdtype)
+
 @with_seed()
 def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True):
   op_name = config[name][OP_NAME]
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index eedc867ce8d3..757df81e1607 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -677,6 +677,101 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)
 
+@with_seed()
+def test_quantize_conv_with_forward():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {})
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+            mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+        batch_size = 4
+        dshape = (batch_size, 4, 10, 10)
+        data = mx.sym.Variable('data')
+        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+
+        mod = Module(symbol=sym, label_names=None)
+        mod.bind(data_shapes=[('data', dshape)])
+
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+        calib_data = mx.nd.random.uniform(shape=dshape)
+        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+    for qdtype in ['uint8', 'int8']:
+        check_quantize_model(qdtype)
+
 @with_seed()
 def test_quantize_sym_with_calib():
     sym = get_fp32_sym()