fix quantize graph pass (apache#14605)
xinyu-intel authored and haohuw committed Jun 23, 2019
1 parent feac703 commit 64a8db7
Showing 3 changed files with 154 additions and 6 deletions.
12 changes: 8 additions & 4 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -265,7 +265,7 @@ Graph QuantizeGraph(Graph &&src) {
(mirror_node->op() != Op::Get("_contrib_dequantize"))) {
// here we calculate the output number (exclude min/max, in order to
// calculate min/max index from mirror node) based on assumption that
// there is only 1min and 1max output from mirror node (which is
// there is only 1 min and 1 max output from mirror node (which is
// currently true)
size_t num_outputs = mirror_node->num_outputs() - 2;
uint32_t min_index = num_outputs + 2 * e.index;
@@ -297,9 +297,13 @@ Graph QuantizeGraph(Graph &&src) {
// Only insert dequantize for those Ops supports quantize and not excluded.
NodePtr mirror_node = mirror_map.at(e.node.get());
NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
size_t num_inputs = e.node->num_inputs();
uint32_t min_index = num_inputs + 2 * e.index;
uint32_t max_index = num_inputs + 2 * e.index + 1;
// here we calculate the output number (exclude min/max, in order to
// calculate min/max index from mirror node) based on assumption that
// there is only 1 min and 1 max output from mirror node (which is
// currently true)
size_t num_outputs = e.node->num_outputs();
uint32_t min_index = num_outputs + 2 * e.index;
uint32_t max_index = num_outputs + 2 * e.index + 1;

NodePtr dequantize_node = CreateNode("_contrib_dequantize",
e.node->attrs.name + "_dequantize");
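The hunk above corrects how the pass locates the min/max entries on a quantized mirror node before wiring in _contrib_dequantize: the index base must be the FP32 node's output count, not its input count, because (per the comment in the pass) a quantized node emits its data outputs followed by one min and one max. A minimal Python sketch of that index arithmetic, for illustration only (the helper name is hypothetical and not part of the patch):

def min_max_entry_indices(num_fp32_outputs, output_index):
    # A quantized mirror node appends its min/max entries after the data
    # outputs, so the pair for data output `output_index` starts right
    # after the first `num_fp32_outputs` slots.
    min_index = num_fp32_outputs + 2 * output_index
    max_index = num_fp32_outputs + 2 * output_index + 1
    return min_index, max_index

# For instance, a convolution with one FP32 output but three inputs
# (data, weight, bias): keying off the input count, as the old code did,
# points the inserted _contrib_dequantize at entries 3 and 4, while keying
# off the output count yields (1, 2), the min/max that actually follow
# the data output.
print(min_max_entry_indices(num_fp32_outputs=1, output_index=0))  # (1, 2)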
53 changes: 51 additions & 2 deletions tests/python/mkl/test_subgraph.py
@@ -85,7 +85,7 @@ def check_qsym_scale_align(qsym):


def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
mod = Module(symbol=qsym, context=mx.current_context())
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])
@@ -96,7 +96,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_
return mod.get_outputs()

def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
mod = Module(symbol=qsym, context=mx.current_context())
mod.bind(for_training=False,
data_shapes=[('data', data_shape)],
label_shapes=[('softmax_label', label_shape)])
@@ -184,6 +184,55 @@ def check_quantize(sym, data_shape, out_type, name='conv',
for i in range(len(ref_out)):
assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)

@with_seed()
def check_quantize_whole_model_with_forward():
def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
mod.bind(for_training=False,
data_shapes=[('data', data_shape)])
mod.set_params(qarg_params, qaux_params)
data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])
mod.forward(batch, is_train=False)
for output in mod.get_outputs():
output.wait_to_read()

def check_quantize_whole_model(out_type):
batch_size = 4
data_shape = (batch_size, 4, 10, 10)
data = mx.sym.Variable('data')
conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')
sym_sg = sym.get_backend_symbol('MKLDNN')
mod = Module(symbol=sym, label_names=[])
mod.bind(for_training=False,
data_shapes=[('data', data_shape)])

mod.init_params(mx.init.Normal(0.5))
arg_params, aux_params = mod.get_params()

excluded_sym_names = []

calib_data = mx.nd.random.uniform(shape=data_shape)
calib_data = NDArrayIter(data=calib_data)
calib_data = DummyIter(calib_data)
calib_layer = lambda name: name.endswith('_output')
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
arg_params=arg_params,
aux_params=aux_params,
ctx=mx.current_context(),
excluded_sym_names=excluded_sym_names,
quantized_dtype=out_type,
calib_mode='naive',
calib_data=calib_data,
calib_layer=calib_layer,
num_calib_examples=5)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
check_qsym_forward(qsym, qarg_params, qaux_params, data_shape)

for qdtype in ['uint8', 'int8', 'auto']:
check_quantize_whole_model(qdtype)

@with_seed()
def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True):
op_name = config[name][OP_NAME]
95 changes: 95 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -677,6 +677,101 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
for qdtype in ['int8', 'uint8']:
check_quantize_model(qdtype)

@with_seed()
def test_quantize_conv_with_forward():
def check_quantize_model(qdtype):
if is_test_for_native_cpu():
print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
return
elif qdtype == 'int8' and is_test_for_mkldnn():
print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
return

def check_params(params, qparams, qsym=None):
if qsym is None:
assert len(params) == len(qparams)
for k, v in params.items():
assert k in qparams
assert same(v.asnumpy(), qparams[k].asnumpy())
else:
qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {})
assert len(qparams) == len(qparams_ground_truth)
for k, v in qparams_ground_truth.items():
assert k in qparams
assert same(v.asnumpy(), qparams[k].asnumpy())

def check_qsym_calibrated(qsym):
attrs = qsym.attr_dict()
for k, v in attrs.items():
if k.find('requantize_') != -1:
assert 'min_calib_range' in v
assert 'max_calib_range' in v

def check_qsym_qdtype(qsym, qdtype):
attrs = qsym.attr_dict()
for k, v in attrs.items():
if k.find('_quantize') != -1:
assert 'out_type' in v
assert v['out_type'] == qdtype

def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
mod.bind(for_training=False,
data_shapes=[('data', data_shape)])
mod.set_params(qarg_params, qaux_params)
data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
batch = mx.io.DataBatch(data, [])
mod.forward(batch, is_train=False)
for output in mod.get_outputs():
output.wait_to_read()

batch_size = 4
dshape = (batch_size, 4, 10, 10)
data = mx.sym.Variable('data')
sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')

mod = Module(symbol=sym, label_names=None)
mod.bind(data_shapes=[('data', dshape)])

mod.init_params()
arg_params, aux_params = mod.get_params()
excluded_sym_names = []

qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
arg_params=arg_params,
aux_params=aux_params,
excluded_sym_names=excluded_sym_names,
ctx=mx.current_context(),
quantized_dtype=qdtype,
calib_mode='none')
check_params(arg_params, qarg_params, qsym)
check_params(aux_params, qaux_params)
check_qsym_forward(qsym, qarg_params, qaux_params, dshape)

calib_data = mx.nd.random.uniform(shape=dshape)
calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
calib_data = DummyIter(calib_data)
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
arg_params=arg_params,
aux_params=aux_params,
excluded_sym_names=excluded_sym_names,
ctx=mx.current_context(),
quantized_dtype=qdtype,
calib_mode='naive',
calib_data=calib_data,
num_calib_examples=20)
check_params(arg_params, qarg_params, qsym)
check_params(aux_params, qaux_params)
check_qsym_calibrated(qsym)
check_qsym_qdtype(qsym, qdtype)
check_qsym_forward(qsym, qarg_params, qaux_params, dshape)

for qdtype in ['uint8', 'int8']:
check_quantize_model(qdtype)

@with_seed()
def test_quantize_sym_with_calib():
sym = get_fp32_sym()
