Skip to content

Commit

Permalink
Fix the quantization script to support Python2 (apache#13700)
Browse files Browse the repository at this point in the history
* fix the quantization script to support python2

* Fix comments, fix similiar issue in imagenet_inference.py
  • Loading branch information
juliusshufan authored and rondogency committed Jan 9, 2019
1 parent b87b55c commit fc6c57c
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
7 changes: 4 additions & 3 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,9 @@ def save_params(fname, arg_params, aux_params, logger=None):
logger.info('rgb_std = %s' % rgb_std)
rgb_std = [float(i) for i in rgb_std.split(',')]
std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}

combine_mean_std = {}
combine_mean_std.update(mean_args)
combine_mean_std.update(std_args)
if calib_mode == 'none':
logger.info('Quantizing FP32 model %s' % args.model)
qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
Expand All @@ -294,8 +296,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
shuffle=args.shuffle_dataset,
shuffle_chunk_seed=args.shuffle_chunk_seed,
seed=args.shuffle_seed,
**mean_args,
**std_args)
**combine_mean_std)

qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params,
ctx=ctx, excluded_sym_names=excluded_sym_names,
Expand Down
6 changes: 4 additions & 2 deletions example/quantization/imagenet_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
logger.info('rgb_std = %s' % rgb_std)
rgb_std = [float(i) for i in rgb_std.split(',')]
std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]}
combine_mean_std = {}
combine_mean_std.update(mean_args)
combine_mean_std.update(std_args)

label_name = args.label_name
logger.info('label_name = %s' % label_name)
Expand All @@ -206,8 +209,7 @@ def benchmark_score(symbol_file, ctx, batch_size, num_batches, logger=None):
shuffle=True,
shuffle_chunk_seed=3982304,
seed=48564309,
**mean_args,
**std_args)
**combine_mean_std)

# loading model
sym, arg_params, aux_params = load_model(symbol_file, param_file, logger)
Expand Down

0 comments on commit fc6c57c

Please sign in to comment.