This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[Quantization] Fix quantize graph pass #14605

Merged
merged 1 commit into from Apr 4, 2019
12 changes: 8 additions & 4 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -265,7 +265,7 @@ Graph QuantizeGraph(Graph &&src) {
           (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
         // here we calculate the output number (exclude min/max, in order to
         // calculate min/max index from mirror node) based on assumption that
-        // there is only 1min and 1max output from mirror node (which is
+        // there is only 1 min and 1 max output from mirror node (which is
         // currently true)
         size_t num_outputs = mirror_node->num_outputs() - 2;
         uint32_t min_index = num_outputs + 2 * e.index;
@@ -297,9 +297,13 @@ Graph QuantizeGraph(Graph &&src) {
       // Only insert dequantize for those Ops supports quantize and not excluded.
       NodePtr mirror_node = mirror_map.at(e.node.get());
       NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
-      size_t num_inputs = e.node->num_inputs();
-      uint32_t min_index = num_inputs + 2 * e.index;
-      uint32_t max_index = num_inputs + 2 * e.index + 1;
+      // here we calculate the output number (exclude min/max, in order to
+      // calculate min/max index from mirror node) based on assumption that
+      // there is only 1 min and 1 max output from mirror node (which is
+      // currently true)
+      size_t num_outputs = e.node->num_outputs();
+      uint32_t min_index = num_outputs + 2 * e.index;
+      uint32_t max_index = num_outputs + 2 * e.index + 1;

       NodePtr dequantize_node = CreateNode("_contrib_dequantize",
                                            e.node->attrs.name + "_dequantize");
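For reference, the hunk above swaps `e.node->num_inputs()` for `e.node->num_outputs()` when locating a quantized node's min/max entries: a quantized operator emits its regular outputs first, followed by the min/max range, so the indices depend on the output count, not the input count. A minimal standalone sketch of that index arithmetic (not MXNet code; `MinMaxIndex` is a hypothetical helper mirroring the formula in the patch):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>

// Hypothetical helper mirroring the patched formula: output entries are laid
// out as [out_0 ... out_{n-1}, min, max], so the min/max of output `i` sit
// right after the regular outputs.
std::pair<uint32_t, uint32_t> MinMaxIndex(size_t num_outputs, uint32_t out_idx) {
  uint32_t min_index = static_cast<uint32_t>(num_outputs) + 2 * out_idx;
  return {min_index, min_index + 1};
}

int main() {
  // A quantized convolution with one regular output: entries are
  // [out, min, max], so its range lives at indices 1 and 2.
  auto mm = MinMaxIndex(1, 0);
  assert(mm.first == 1 && mm.second == 2);
  // Indexing by the node's input count instead (a conv typically has
  // data/weight/bias inputs) would point at the wrong entries, which is
  // the mis-indexing this patch corrects.
  return 0;
}
```

The tests added below exercise this path end to end by running a forward pass on the quantized model.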
53 changes: 51 additions & 2 deletions tests/python/mkl/test_subgraph.py
@@ -85,7 +85,7 @@ def check_qsym_scale_align(qsym):


 def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -96,7 +96,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_
   return mod.get_outputs()

 def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -184,6 +184,55 @@ def check_quantize(sym, data_shape, out_type, name='conv',
   for i in range(len(ref_out)):
     assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)

+@with_seed()
+def check_quantize_whole_model_with_forward():
+  def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+    mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+    mod.set_params(qarg_params, qaux_params)
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+      output.wait_to_read()
+
+  def check_quantize_whole_model(out_type):
+    batch_size = 4
+    data_shape = (batch_size, 4, 10, 10)
+    data = mx.sym.Variable('data')
+    conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+    sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')
+    sym_sg = sym.get_backend_symbol('MKLDNN')
+    mod = Module(symbol=sym, label_names=[])
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+
+    mod.init_params(mx.init.Normal(0.5))
+    arg_params, aux_params = mod.get_params()
+
+    excluded_sym_names = []
+
+    calib_data = mx.nd.random.uniform(shape=data_shape)
+    calib_data = NDArrayIter(data=calib_data)
+    calib_data = DummyIter(calib_data)
+    calib_layer = lambda name: name.endswith('_output')
+    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                     arg_params=arg_params,
+                                                                     aux_params=aux_params,
+                                                                     ctx=mx.current_context(),
+                                                                     excluded_sym_names=excluded_sym_names,
+                                                                     quantized_dtype=out_type,
+                                                                     calib_mode='naive',
+                                                                     calib_data=calib_data,
+                                                                     calib_layer=calib_layer,
+                                                                     num_calib_examples=5)
+    qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
+    check_qsym_forward(qsym, qarg_params, qaux_params, data_shape)
+
+  for qdtype in ['uint8', 'int8', 'auto']:
+    check_quantize_whole_model(qdtype)
+
 @with_seed()
 def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True):
   op_name = config[name][OP_NAME]
95 changes: 95 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -677,6 +677,101 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape):
     for qdtype in ['int8', 'uint8']:
         check_quantize_model(qdtype)

+@with_seed()
+def test_quantize_conv_with_forward():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {})
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+            mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+        batch_size = 4
+        dshape = (batch_size, 4, 10, 10)
+        data = mx.sym.Variable('data')
+        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+
+        mod = Module(symbol=sym, label_names=None)
+        mod.bind(data_shapes=[('data', dshape)])
+
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+        calib_data = mx.nd.random.uniform(shape=dshape)
+        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+    for qdtype in ['uint8', 'int8']:
+        check_quantize_model(qdtype)
+
 @with_seed()
 def test_quantize_sym_with_calib():
     sym = get_fp32_sym()