diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index c38019fbe7b9..938890bb75df 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -273,7 +273,9 @@ def save_params(fname, arg_params, aux_params, logger=None): logger.info('rgb_std = %s' % rgb_std) rgb_std = [float(i) for i in rgb_std.split(',')] std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} - + combine_mean_std = {} + combine_mean_std.update(mean_args) + combine_mean_std.update(std_args) if calib_mode == 'none': logger.info('Quantizing FP32 model %s' % args.model) qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, @@ -294,8 +296,7 @@ def save_params(fname, arg_params, aux_params, logger=None): shuffle=args.shuffle_dataset, shuffle_chunk_seed=args.shuffle_chunk_seed, seed=args.shuffle_seed, - **mean_args, - **std_args) + **combine_mean_std) qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=ctx, excluded_sym_names=excluded_sym_names, diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py index 3fdb52f40cb2..0725165b0ca5 100644 --- a/example/quantization/imagenet_inference.py +++ b/example/quantization/imagenet_inference.py @@ -181,6 +181,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): logger.info('rgb_std = %s' % rgb_std) rgb_std = [float(i) for i in rgb_std.split(',')] std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} + combine_mean_std = {} + combine_mean_std.update(mean_args) + combine_mean_std.update(std_args) label_name = args.label_name logger.info('label_name = %s' % label_name) @@ -206,8 +209,7 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None): shuffle=True, shuffle_chunk_seed=3982304, seed=48564309, - **mean_args, - **std_args) + **combine_mean_std) # loading model sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)