diff --git a/python/mxnet/quantization.py b/python/mxnet/quantization.py index 1b68d1819992..293dc4cd47d3 100644 --- a/python/mxnet/quantization.py +++ b/python/mxnet/quantization.py @@ -273,16 +273,14 @@ def _get_optimal_thresholds(nd_dict, num_bins=8001, num_quantized_bins=255, logg logger.info('Calculating optimal thresholds for quantization using KL divergence' ' with num_bins=%d and num_quantized_bins=%d' % (num_bins, num_quantized_bins)) th_dict = {} - layer_names = nd_dict.keys() - for name in layer_names: - assert name in nd_dict - min_val, max_val, min_divergence, opt_th = _get_optimal_threshold(nd_dict[name], num_bins=num_bins, + for k, v in nd_dict: + min_val, max_val, min_divergence, opt_th = _get_optimal_threshold(v, num_bins=num_bins, num_quantized_bins=num_quantized_bins) - del nd_dict[name] # release the memory of ndarray - th_dict[name] = (-opt_th, opt_th) + del v # release the memory of ndarray + th_dict[k] = (-opt_th, opt_th) if logger is not None: logger.info('layer=%s, min_val=%f, max_val=%f, min_divergence=%f, optimal_threshold=%f' - % (name, min_val, max_val, min_divergence, opt_th)) + % (k, min_val, max_val, min_divergence, opt_th)) return th_dict