diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 7f5b1ca29b1c..c9c58a9c9ba4 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -452,9 +452,11 @@ def quantize_model(sym, arg_params, aux_params, otherwise, no information of the layer's output will be collected. If not provided, all the layers' outputs that need requantization will be collected. logger : Object + A logging object for printing information during the process of quantization. Returns - `(qsym, qarg_params, aux_params)` + ------- + tuple A tuple of quantized symbol, quantized arg_params, and aux_params. ------- """ @@ -471,7 +473,8 @@ def quantize_model(sym, arg_params, aux_params, idx = nodes.list_outputs().index(sym_name + '_output') excluded_syms.append(nodes[idx]) logger.info('Quantizing symbol') - qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, offline_params=arg_params.keys()) + qsym = _quantize_symbol(sym, excluded_symbols=excluded_syms, + offline_params=list(arg_params.keys())) logger.info('Quantizing parameters') qarg_params = _quantize_params(qsym, arg_params)